Support loras in diffusers format.
This commit is contained in:
parent
5a90d3cea5
commit
c5d7593ccf
24
comfy/sd.py
24
comfy/sd.py
|
@ -70,13 +70,22 @@ def load_lora(lora, to_load):
|
||||||
alpha = lora[alpha_name].item()
|
alpha = lora[alpha_name].item()
|
||||||
loaded_keys.add(alpha_name)
|
loaded_keys.add(alpha_name)
|
||||||
|
|
||||||
A_name = "{}.lora_up.weight".format(x)
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
B_name = "{}.lora_down.weight".format(x)
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
mid_name = "{}.lora_mid.weight".format(x)
|
A_name = None
|
||||||
|
|
||||||
if A_name in lora.keys():
|
if regular_lora in lora.keys():
|
||||||
|
A_name = regular_lora
|
||||||
|
B_name = "{}.lora_down.weight".format(x)
|
||||||
|
mid_name = "{}.lora_mid.weight".format(x)
|
||||||
|
elif diffusers_lora in lora.keys():
|
||||||
|
A_name = diffusers_lora
|
||||||
|
B_name = "{}_lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
|
||||||
|
if A_name is not None:
|
||||||
mid = None
|
mid = None
|
||||||
if mid_name in lora.keys():
|
if mid_name is not None and mid_name in lora.keys():
|
||||||
mid = lora[mid_name]
|
mid = lora[mid_name]
|
||||||
loaded_keys.add(mid_name)
|
loaded_keys.add(mid_name)
|
||||||
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||||
|
@ -202,6 +211,11 @@ def model_lora_keys_unet(model, key_map={}):
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||||
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
||||||
|
|
||||||
|
diffusers_lora_key = "unet.{}".format(k[:-len(".weight")].replace(".to_", ".processor.to_"))
|
||||||
|
if diffusers_lora_key.endswith(".to_out.0"):
|
||||||
|
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||||
|
key_map[diffusers_lora_key] = "diffusion_model.{}".format(diffusers_keys[k])
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
def set_attr(obj, attr, value):
|
def set_attr(obj, attr, value):
|
||||||
|
|
Loading…
Reference in New Issue