diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 44b82795..bf878776 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -210,16 +210,19 @@ class ModelPatcher: model_sd = self.model.state_dict() for k in patches: offset = None + function = None if isinstance(k, str): key = k else: offset = k[1] key = k[0] + if len(k) > 2: + function = k[2] if key in model_sd: p.add(k) current_patches = self.patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset)) + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) self.patches[key] = current_patches self.patches_uuid = uuid.uuid4() @@ -347,6 +350,9 @@ class ModelPatcher: v = p[1] strength_model = p[2] offset = p[3] + function = p[4] + if function is None: + function = lambda a: a old_weight = None if offset is not None: @@ -371,7 +377,7 @@ class ModelPatcher: if w1.shape != weight.shape: logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: - weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) + weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)) elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) @@ -389,9 +395,9 @@ class ModelPatcher: try: lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) else: - weight += ((strength * alpha) * lora_diff).type(weight.dtype) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "lokr": @@ -435,9 +441,9 @@ class ModelPatcher: try: lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) else: - weight += ((strength * alpha) * lora_diff).type(weight.dtype) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "loha": @@ -472,9 +478,9 @@ class ModelPatcher: try: lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) else: - weight += ((strength * alpha) * lora_diff).type(weight.dtype) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "glora": @@ -493,9 +499,9 @@ class ModelPatcher: try: lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) else: - weight += ((strength * alpha) * lora_diff).type(weight.dtype) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: