Fix issue with full diffusers SD3 loras.
This commit is contained in:
parent 0d6a57938e
commit 028a583bef
@@ -210,16 +210,19 @@ class ModelPatcher:
         model_sd = self.model.state_dict()
         for k in patches:
             offset = None
+            function = None
             if isinstance(k, str):
                 key = k
             else:
                 offset = k[1]
                 key = k[0]
+                if len(k) > 2:
+                    function = k[2]

             if key in model_sd:
                 p.add(k)
                 current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset))
+                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                 self.patches[key] = current_patches

         self.patches_uuid = uuid.uuid4()
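To make the new tuple element concrete, here is a minimal sketch of the two key forms the loop above accepts and the 5-tuple it now stores per match. The key name, payloads, and transform below are made up for illustration and are not taken from ComfyUI:

import torch

diff_a = (torch.randn(8, 8),)   # placeholder patch payloads; real payloads come from the lora loader
diff_b = (torch.randn(8, 8),)

patches = {
    # 1) plain string key: no offset, no transform
    "some.model.key.weight": diff_a,

    # 2) tuple key: (state_dict key, offset, optional transform).
    #    The third element is what this commit adds; 2-tuples still work
    #    because of the len(k) > 2 check above.
    ("some.model.key.weight", None, lambda a: a * 0.5): diff_b,
}

# For every key found in the model's state_dict, the loop records:
#     (strength_patch, patches[k], strength_model, offset, function)
# where function stays None unless the tuple key supplied one.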
@@ -347,6 +350,9 @@ class ModelPatcher:
             v = p[1]
             strength_model = p[2]
             offset = p[3]
+            function = p[4]
+            if function is None:
+                function = lambda a: a

             old_weight = None
             if offset is not None:
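As a toy, self-contained illustration of how the unpacked function slots into the weight update, with the same identity default as above (this is not ComfyUI code; names, shapes, and the per-patch math are stand-ins):

import torch

def apply_patch(p, weight, delta):
    strength, _v, _strength_model, _offset, function = p
    if function is None:
        function = lambda a: a   # identity: behaves exactly like before this commit
    # the transform is applied to the computed delta, never to the base weight
    return weight + function(strength * delta).type(weight.dtype)

weight = torch.zeros(4, 4)
delta = torch.ones(4, 4)
patch = (1.0, None, 1.0, None, lambda a: a * 0.5)  # hypothetical transform in slot 4
print(apply_patch(patch, weight, delta))           # 0.5 everywhere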
@@ -371,7 +377,7 @@ class ModelPatcher:
                     if w1.shape != weight.shape:
                         logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                     else:
-                        weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
+                        weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
             elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
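Every patch type below gets the same treatment: the computed delta is passed through function before it is merged into the weight. As an illustration of the kind of transform this enables (purely hypothetical here; the actual key-to-transform mapping for diffusers-format SD3 loras is presumably set up by the lora loading code, not in this file), a delta whose two halves are stored in the opposite order to the native weight can be reordered just before merging:

import torch

def swap_halves(delta):
    # hypothetical transform passed as `function`: swap the two halves along dim 0,
    # e.g. a modulation weight stored as (shift, scale) in one format and
    # (scale, shift) in the other
    a, b = delta.chunk(2, dim=0)
    return torch.cat([b, a], dim=0)

delta = torch.cat([torch.zeros(2, 4), torch.ones(2, 4)], dim=0)
print(swap_halves(delta))   # the ones block now comes first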
@@ -389,9 +395,9 @@ class ModelPatcher:
                 try:
                     lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
@@ -435,9 +441,9 @@ class ModelPatcher:
                 try:
                     lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
@@ -472,9 +478,9 @@ class ModelPatcher:
                 try:
                     lora_diff = (m1 * m2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
@@ -493,9 +499,9 @@ class ModelPatcher:
                 try:
                     lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else: