From 32a60a7bacc9f623099ef41dcc1b4a7a2d22f23d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 8 Sep 2024 09:31:41 -0400
Subject: [PATCH] Support onetrainer text encoder Flux lora.

---
 comfy/lora.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/comfy/lora.py b/comfy/lora.py
index ad951bba..02c27bf0 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -207,6 +207,7 @@ def model_lora_keys_clip(model, key_map={}):
 
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
+    clip_g_present = False
     for b in range(32): #TODO: clean up
         for c in LORA_CLIP_MAP:
             k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@@ -230,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
 
             k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
+                clip_g_present = True
                 if clip_l_present:
                     lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                     key_map[lora_key] = k
@@ -245,9 +247,15 @@ def model_lora_keys_clip(model, key_map={}):
 
     for k in sdk:
         if k.endswith(".weight"):
-            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
+            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
+                t5_index = 1
+                if clip_l_present:
+                    t5_index += 1
+                if clip_g_present:
+                    t5_index += 1
+
                 l_key = k[len("t5xxl.transformer."):-len(".weight")]
-                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
+                lora_key = "lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))
                 key_map[lora_key] = k
             elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
                 l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
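
Note (not part of the patch): a minimal standalone sketch of the te-index
numbering the hunk at @@ -245,9 +247,15 @@ introduces. Previously T5XXL keys
were hardcoded to the lora_te3_* prefix (correct for SD3's three text
encoders); the patch numbers T5XXL after whichever CLIP encoders are present,
so a Flux checkpoint (clip_l + t5xxl) maps to lora_te2_*. The helper name
t5xxl_lora_prefix and the assertions below are illustrative assumptions, not
code from comfy/lora.py.

    # Sketch of the index logic; the real code runs inline inside
    # model_lora_keys_clip in comfy/lora.py.
    def t5xxl_lora_prefix(clip_l_present: bool, clip_g_present: bool) -> str:
        # T5XXL is numbered after the CLIP encoders that precede it:
        # te1 if it is the only text encoder, te2 if one CLIP is present,
        # te3 if both clip_l and clip_g are present (the SD3 layout).
        t5_index = 1
        if clip_l_present:
            t5_index += 1
        if clip_g_present:
            t5_index += 1
        return "lora_te{}".format(t5_index)

    # SD3 (clip_l + clip_g + t5xxl) keeps the old lora_te3_* prefix,
    # while Flux (clip_l + t5xxl) now gets lora_te2_*.
    assert t5xxl_lora_prefix(clip_l_present=True, clip_g_present=True) == "lora_te3"
    assert t5xxl_lora_prefix(clip_l_present=True, clip_g_present=False) == "lora_te2"
    assert t5xxl_lora_prefix(clip_l_present=False, clip_g_present=False) == "lora_te1"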