Add a node to merge CLIP models.

This commit is contained in:
comfyanonymous 2023-07-14 02:37:30 -04:00
parent 907c9fbf0d
commit 91ed2815d5
2 changed files with 27 additions and 2 deletions

View File

@ -479,8 +479,8 @@ class CLIP:
def load_from_state_dict(self, sd):
self.cond_stage_model.load_sd(sd)
def add_patches(self, patches, strength=1.0):
return self.patcher.add_patches(patches, strength)
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
@ -514,6 +514,9 @@ class CLIP:
def unpatch_model(self):
self.patcher.unpatch_model()
def get_key_patches(self):
return self.patcher.get_key_patches()
class VAE:
def __init__(self, ckpt_path=None, device=None, config=None):
if config is None:

View File

@ -23,6 +23,27 @@ class ModelMergeSimple:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
class CLIPMergeSimple:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, ratio):
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
class ModelMergeBlocks:
@classmethod
def INPUT_TYPES(s):
@ -94,4 +115,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
}