Support onetrainer text encoder Flux lora.

This commit is contained in:
comfyanonymous 2024-09-08 09:31:41 -04:00
parent bb52934ba4
commit 32a60a7bac
1 changed files with 10 additions and 2 deletions

View File

@ -207,6 +207,7 @@ def model_lora_keys_clip(model, key_map={}):
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False clip_l_present = False
clip_g_present = False
for b in range(32): #TODO: clean up for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP: for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@ -230,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk: if k in sdk:
clip_g_present = True
if clip_l_present: if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k key_map[lora_key] = k
@ -245,9 +247,15 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk: for k in sdk:
if k.endswith(".weight"): if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
t5_index = 1
if clip_l_present:
t5_index += 1
if clip_g_present:
t5_index += 1
l_key = k[len("t5xxl.transformer."):-len(".weight")] l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) lora_key = "lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))
key_map[lora_key] = k key_map[lora_key] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")] l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]