From dd095efc2c172000a840c7ca54c8a01458f6b2a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Mar 2023 04:32:25 -0400 Subject: [PATCH] Support loha that use cp decomposition. --- comfy/sd.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 714fa66b..d767d867 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -149,8 +149,18 @@ def load_lora(path, to_load): hada_w1_b_name = "{}.hada_w1_b".format(x) hada_w2_a_name = "{}.hada_w2_a".format(x) hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) if hada_w1_a_name in lora.keys(): - patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name]) + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -312,7 +322,16 @@ class ModelPatcher: alpha *= v[2] / w1b.shape[0] w2a = v[3] w2b = v[4] - weight += (alpha * torch.mm(w1a.float(), w1b.float()) * torch.mm(w2a.float(), w2b.float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + if v[5] is not None: #cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float()) + m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float()) + else: + m1 = torch.mm(w1a.float(), w1b.float()) + m2 = torch.mm(w2a.float(), w2b.float()) + + weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) return self.model def unpatch_model(self): model_sd = self.model.state_dict()