diff --git a/comfy/sd.py b/comfy/sd.py
index 76eaa5b5..526dc531 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -206,7 +206,7 @@ class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
         self.model = model
-        self.patches = []
+        self.patches = {}
         self.backup = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
@@ -227,7 +227,10 @@ class ModelPatcher:
 
     def clone(self):
         n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
-        n.patches = self.patches[:]
+        n.patches = {}
+        for k in self.patches:
+            n.patches[k] = self.patches[k][:]
+
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
         return n
@@ -295,12 +298,28 @@
         return self.model.get_dtype()
 
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = {}
+        p = set()
         for k in patches:
             if k in self.model_keys:
-                p[k] = patches[k]
-        self.patches += [(strength_patch, p, strength_model)]
-        return p.keys()
+                p.add(k)
+                current_patches = self.patches.get(k, [])
+                current_patches.append((strength_patch, patches[k], strength_model))
+                self.patches[k] = current_patches
+
+        return list(p)
+
+    def get_key_patches(self, filter_prefix=None):
+        model_sd = self.model_state_dict()
+        p = {}
+        for k in model_sd:
+            if filter_prefix is not None:
+                if not k.startswith(filter_prefix):
+                    continue
+            if k in self.patches:
+                p[k] = [model_sd[k]] + self.patches[k]
+            else:
+                p[k] = (model_sd[k],)
+        return p
 
     def model_state_dict(self, filter_prefix=None):
         sd = self.model.state_dict()
@@ -313,85 +332,93 @@
 
     def patch_model(self):
         model_sd = self.model_state_dict()
-        for p in self.patches:
-            for k in p[1]:
-                v = p[1][k]
-                key = k
-                if key not in model_sd:
-                    print("could not patch. key doesn't exist in model:", k)
-                    continue
+        for key in self.patches:
+            if key not in model_sd:
+                print("could not patch. key doesn't exist in model:", key)
+                continue
 
-                weight = model_sd[key]
-                if key not in self.backup:
-                    self.backup[key] = weight.clone()
+            weight = model_sd[key]
 
-                alpha = p[0]
-                strength_model = p[2]
+            if key not in self.backup:
+                self.backup[key] = weight.clone()
 
-                if strength_model != 1.0:
-                    weight *= strength_model
-
-                if len(v) == 1:
-                    w1 = v[0]
-                    if w1.shape != weight.shape:
-                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
-                    else:
-                        weight += alpha * w1.type(weight.dtype).to(weight.device)
-                elif len(v) == 4: #lora/locon
-                    mat1 = v[0]
-                    mat2 = v[1]
-                    if v[2] is not None:
-                        alpha *= v[2] / mat2.shape[0]
-                    if v[3] is not None:
-                        #locon mid weights, hopefully the math is fine because I didn't properly test it
-                        final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
-                        mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
-                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
-                elif len(v) == 8: #lokr
-                    w1 = v[0]
-                    w2 = v[1]
-                    w1_a = v[3]
-                    w1_b = v[4]
-                    w2_a = v[5]
-                    w2_b = v[6]
-                    t2 = v[7]
-                    dim = None
-
-                    if w1 is None:
-                        dim = w1_b.shape[0]
-                        w1 = torch.mm(w1_a.float(), w1_b.float())
-
-                    if w2 is None:
-                        dim = w2_b.shape[0]
-                        if t2 is None:
-                            w2 = torch.mm(w2_a.float(), w2_b.float())
-                        else:
-                            w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
-
-                    if len(w2.shape) == 4:
-                        w1 = w1.unsqueeze(2).unsqueeze(2)
-                    if v[2] is not None and dim is not None:
-                        alpha *= v[2] / dim
-
-                    weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
-                else: #loha
-                    w1a = v[0]
-                    w1b = v[1]
-                    if v[2] is not None:
-                        alpha *= v[2] / w1b.shape[0]
-                    w2a = v[3]
-                    w2b = v[4]
-                    if v[5] is not None: #cp decomposition
-                        t1 = v[5]
-                        t2 = v[6]
-                        m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
-                        m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
-                    else:
-                        m1 = torch.mm(w1a.float(), w1b.float())
-                        m2 = torch.mm(w2a.float(), w2b.float())
-
-                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
+            weight[:] = self.calculate_weight(self.patches[key], weight.clone(), key)
         return self.model
+
+    def calculate_weight(self, patches, weight, key):
+        for p in patches:
+            alpha = p[0]
+            v = p[1]
+            strength_model = p[2]
+
+            if strength_model != 1.0:
+                weight *= strength_model
+
+            if isinstance(v, list):
+                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
+
+            if len(v) == 1:
+                w1 = v[0]
+                if w1.shape != weight.shape:
+                    print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += alpha * w1.type(weight.dtype).to(weight.device)
+            elif len(v) == 4: #lora/locon
+                mat1 = v[0]
+                mat2 = v[1]
+                if v[2] is not None:
+                    alpha *= v[2] / mat2.shape[0]
+                if v[3] is not None:
+                    #locon mid weights, hopefully the math is fine because I didn't properly test it
+                    final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
+                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+            elif len(v) == 8: #lokr
+                w1 = v[0]
+                w2 = v[1]
+                w1_a = v[3]
+                w1_b = v[4]
+                w2_a = v[5]
+                w2_b = v[6]
+                t2 = v[7]
+                dim = None
+
+                if w1 is None:
+                    dim = w1_b.shape[0]
+                    w1 = torch.mm(w1_a.float(), w1_b.float())
+
+                if w2 is None:
+                    dim = w2_b.shape[0]
+                    if t2 is None:
+                        w2 = torch.mm(w2_a.float(), w2_b.float())
+                    else:
+                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
+
+                if len(w2.shape) == 4:
+                    w1 = w1.unsqueeze(2).unsqueeze(2)
+                if v[2] is not None and dim is not None:
+                    alpha *= v[2] / dim
+
+                weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
+            else: #loha
+                w1a = v[0]
+                w1b = v[1]
+                if v[2] is not None:
+                    alpha *= v[2] / w1b.shape[0]
+                w2a = v[3]
+                w2b = v[4]
+                if v[5] is not None: #cp decomposition
+                    t1 = v[5]
+                    t2 = v[6]
+                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
+                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
+                else:
+                    m1 = torch.mm(w1a.float(), w1b.float())
+                    m2 = torch.mm(w2a.float(), w2b.float())
+
+                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
+        return weight
+
     def unpatch_model(self):
         model_sd = self.model_state_dict()
         keys = list(self.backup.keys())
diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py
index 9bbb84da..eae9b6fd 100644
--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py
@@ -18,9 +18,9 @@ class ModelMergeSimple:
 
     def merge(self, model1, model2, ratio):
         m = model1.clone()
-        sd = model2.model_state_dict("diffusion_model.")
-        for k in sd:
-            m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
+        kp = model2.get_key_patches("diffusion_model.")
+        for k in kp:
+            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
         return (m, )
 
 class ModelMergeBlocks:
@@ -39,10 +39,10 @@
 
     def merge(self, model1, model2, **kwargs):
         m = model1.clone()
-        sd = model2.model_state_dict("diffusion_model.")
+        kp = model2.get_key_patches("diffusion_model.")
         default_ratio = next(iter(kwargs.values()))
 
-        for k in sd:
+        for k in kp:
             ratio = default_ratio
             k_unet = k[len("diffusion_model."):]
 
@@ -52,7 +52,7 @@
                     ratio = kwargs[arg]
                     last_arg_size = len(arg)
 
-            m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
+            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
         return (m, )
 
 class CheckpointSave:
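
Reviewer note, not part of the patch: the sketch below is a minimal, standalone illustration of the arithmetic the new calculate_weight() performs for the ModelMergeSimple case, covering only the plain weight-diff branch (len(v) == 1) and the list handling produced by get_key_patches(). The helper name calc and the toy tensors are made up for illustration; only torch is assumed.

import torch

def calc(patches, weight):
    # Simplified mirror of ModelPatcher.calculate_weight: per-key patches are
    # (strength_patch, value, strength_model) tuples applied in order.
    for strength_patch, v, strength_model in patches:
        if strength_model != 1.0:
            weight *= strength_model              # scale the base model's weight
        if isinstance(v, list):
            # get_key_patches() entries look like [base_weight, *patches]:
            # collapse them to a single tensor first, as calculate_weight does.
            v = (calc(v[1:], v[0].clone()),)
        weight += strength_patch * v[0].to(weight.dtype)
    return weight

w1 = torch.randn(4, 4)      # weight of model1 (the clone receiving the patch)
w2 = torch.randn(4, 4)      # same key taken from model2 via get_key_patches()
ratio = 0.3

# ModelMergeSimple stores one patch per key: strength_patch = 1 - ratio,
# strength_model = ratio, value = model2's entry (a list here, to exercise
# the recursion).
merged = calc([(1.0 - ratio, [w2], ratio)], w1.clone())
assert torch.allclose(merged, ratio * w1 + (1.0 - ratio) * w2)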