Support simple diffusers Flux loras.
parent 7afa985fba
commit 78e133d041
@@ -288,4 +288,12 @@ def model_lora_keys_unet(model, key_map={}):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format

    if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
                key_map[key_lora] = to

    return key_map
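For illustration, a minimal standalone sketch of what this mapping produces. The hidden_size value and the single diffusers_keys entry below are assumptions mirroring the shapes flux_to_diffusers emits, not part of the commit:

#Standalone sketch of the key translation above; values are assumed examples.
hidden_size = 3072 #assumption for illustration

#flux_to_diffusers returns entries shaped like (fused weight key, (dim, offset, size)):
diffusers_keys = {
    "transformer_blocks.0.attn.to_q.weight":
        ("diffusion_model.double_blocks.0.img_attn.qkv.weight", (0, 0, hidden_size)),
}

key_map = {}
for k in diffusers_keys:
    if k.endswith(".weight"):
        to = diffusers_keys[k]
        key_map["transformer.{}".format(k[:-len(".weight")])] = to

#A lora key like "transformer.transformer_blocks.0.attn.to_q" now resolves to
#a slice of the fused qkv weight:
print(key_map["transformer.transformer_blocks.0.attn.to_q"])
#('diffusion_model.double_blocks.0.img_attn.qkv.weight', (0, 0, 3072))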
@@ -415,6 +415,59 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    return key_map

def flux_to_diffusers(mmdit_config, output_prefix=""):
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)

    key_map = {}
    for index in range(n_double_layers):
        prefix_from = "transformer_blocks.{}".format(index)
        prefix_to = "{}double_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
            #values are (fused weight key, (dim, offset, size)): the lora targets a slice of the fused qkv tensor
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

        block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
                     "attn.to_out.0.bias": "img_attn.proj.bias",
                     }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    for index in range(n_single_layers):
        prefix_from = "single_transformer_blocks.{}".format(index)
        prefix_to = "{}single_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            k = "{}.attn.".format(prefix_from)
            #single blocks fuse q, k, v and the mlp input projection into one linear1 weight
            qkv = "{}.linear1.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
            key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))

        block_map = {#TODO
                    }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    MAP_BASIC = { #TODO
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map

def repeat_to_batch_size(tensor, batch_size, dim=0):
    if tensor.shape[dim] > batch_size:
        return tensor.narrow(dim, 0, batch_size)
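A usage sketch of the new conversion function. The config values below are assumptions roughly matching Flux-dev dimensions, not taken from the commit:

import comfy.utils

#assumed Flux-dev-like dimensions, for illustration only
unet_config = {"depth": 19, "depth_single_blocks": 38, "hidden_size": 3072}
key_map = comfy.utils.flux_to_diffusers(unet_config, output_prefix="diffusion_model.")

#double blocks: q/k/v map to consecutive slices of the fused img_attn.qkv weight
print(key_map["transformer_blocks.0.attn.to_k.weight"])
#('diffusion_model.double_blocks.0.img_attn.qkv.weight', (0, 3072, 3072))

#single blocks: q/k/v/proj_mlp map to consecutive slices of linear1
print(key_map["single_transformer_blocks.0.attn.to_v.weight"])
#('diffusion_model.single_blocks.0.linear1.weight', (0, 6144, 3072))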