Make XL checkpoints save in a more standard format.
This commit is contained in:
parent
b416be7d78
commit
1e0fcc9a65
|
@ -190,12 +190,16 @@ class SDXL(supported_models_base.BASE):
|
||||||
replace_prefix = {}
|
replace_prefix = {}
|
||||||
keys_to_replace = {}
|
keys_to_replace = {}
|
||||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||||
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
|
|
||||||
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
if k.startswith("clip_l"):
|
if k.startswith("clip_l"):
|
||||||
state_dict_g[k] = state_dict[k]
|
state_dict_g[k] = state_dict[k]
|
||||||
|
|
||||||
|
state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
|
||||||
|
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
|
||||||
|
for p in pop_keys:
|
||||||
|
if p in state_dict_g:
|
||||||
|
state_dict_g.pop(p)
|
||||||
|
|
||||||
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
|
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
|
||||||
replace_prefix["clip_l"] = "conditioner.embedders.0"
|
replace_prefix["clip_l"] = "conditioner.embedders.0"
|
||||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||||
|
|
Loading…
Reference in New Issue