2023-06-26 16:21:07 +00:00
|
|
|
import comfy.sd
|
|
|
|
import comfy.utils
|
|
|
|
import folder_paths
|
|
|
|
import json
|
|
|
|
import os
|
2023-06-20 23:17:03 +00:00
|
|
|
|
|
|
|
class ModelMergeSimple:
    """Node that blends two MODEL inputs using a single global ratio.

    A clone of ``model1`` receives one patch per entry of ``model2``'s
    ``diffusion_model.``-prefixed state dict. Every patch uses the same pair
    of strengths, ``(1.0 - ratio, ratio)``.
    """

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two MODEL sockets plus a ``ratio`` slider in [0, 1] (default 1.0)."""
        ratio_spec = {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}
        return {
            "required": {
                "model1": ("MODEL",),
                "model2": ("MODEL",),
                "ratio": ("FLOAT", ratio_spec),
            }
        }

    def merge(self, model1, model2, ratio):
        """Return a 1-tuple holding a clone of ``model1`` patched with ``model2``'s weights.

        Args:
            model1: Model that is cloned; the clone is what gets patched and returned.
            model2: Model whose ``model_state_dict("diffusion_model.")`` entries
                are used as patch values.
            ratio: Passed as the third argument of ``add_patches``, with
                ``1.0 - ratio`` as the second. NOTE(review): presumably
                (strength_patch, strength_model), so ``ratio=1.0`` would keep
                model1 as-is. Confirm against ``add_patches``.

        Returns:
            ``(patched_clone,)``
        """
        merged = model1.clone()
        donor_weights = model2.model_state_dict("diffusion_model.")
        keep = 1.0 - ratio  # identical for every key, so compute it once
        for name in donor_weights:
            merged.add_patches({name: (donor_weights[name],)}, keep, ratio)
        return (merged,)
|
|
|
|
|
|
|
|
class ModelMergeBlocks:
    """Node that blends two MODEL inputs with a separate ratio per key prefix.

    Every keyword argument passed to ``merge`` is treated as a key prefix
    (relative to ``diffusion_model.``) mapped to a ratio. Each weight uses the
    ratio of the longest prefix that matches it. Weights that match no prefix
    fall back to the value of the FIRST keyword argument.
    """

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two MODEL sockets plus ``input``/``middle``/``out`` ratio sliders."""
        def ratio_input():
            # Build a fresh spec dict for each slider.
            return ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        return {
            "required": {
                "model1": ("MODEL",),
                "model2": ("MODEL",),
                "input": ratio_input(),
                "middle": ratio_input(),
                "out": ratio_input(),
            }
        }

    def merge(self, model1, model2, **kwargs):
        """Return a 1-tuple holding a clone of ``model1`` patched per prefix.

        Args:
            model1: Model that is cloned and patched.
            model2: Model supplying ``model_state_dict("diffusion_model.")`` entries.
            **kwargs: Prefix -> ratio. Matching is plain ``str.startswith``, so
                e.g. ``"out"`` also matches keys beginning ``"output_blocks"``.
                Must be non-empty: the first value is the fallback ratio.

        Returns:
            ``(patched_clone,)``
        """
        merged = model1.clone()
        donor_weights = model2.model_state_dict("diffusion_model.")
        # Insertion order of kwargs decides the fallback ratio.
        fallback = next(iter(kwargs.values()))
        # NOTE(review): assumes every key starts with "diffusion_model." — the
        # slice below drops that many characters unconditionally.
        strip = len("diffusion_model.")

        for full_key in donor_weights:
            unet_key = full_key[strip:]
            chosen, best_len = fallback, 0
            for prefix, value in kwargs.items():
                # Strictly-longer only, so an empty prefix never wins.
                if len(prefix) > best_len and unet_key.startswith(prefix):
                    chosen, best_len = value, len(prefix)
            merged.add_patches({full_key: (donor_weights[full_key],)}, 1.0 - chosen, chosen)

        return (merged,)
|
|
|
|
|
2023-06-26 16:21:07 +00:00
|
|
|
class CheckpointSave:
    """Output node that writes MODEL, CLIP and VAE to one ``.safetensors`` file.

    The file is named ``<filename>_<counter:05>_.safetensors`` inside the
    folder chosen by ``folder_paths.get_save_image_path``. Its metadata holds
    JSON-encoded strings: ``"prompt"`` (empty string when no prompt is given)
    plus one entry per key of ``extra_pnginfo``. An ``extra_pnginfo`` entry
    named ``"prompt"`` overwrites the prompt value.
    """

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(cls):
        """Declare model/clip/vae sockets, a filename prefix, and hidden prompt inputs."""
        required = {
            "model": ("MODEL",),
            "clip": ("CLIP",),
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Write the checkpoint and return an empty dict.

        Args:
            model, clip, vae: Objects forwarded unchanged to ``comfy.sd.save_checkpoint``.
            filename_prefix: Forwarded to ``folder_paths.get_save_image_path``
                together with ``self.output_dir``.
            prompt: Optional object JSON-encoded into the ``"prompt"`` metadata entry.
            extra_pnginfo: Optional mapping; each value is JSON-encoded under its own key.
        """
        # Only the folder, base filename and counter are used below.
        target_dir, base_name, counter, _subfolder, _prefix = folder_paths.get_save_image_path(
            filename_prefix, self.output_dir
        )

        metadata = {"prompt": "" if prompt is None else json.dumps(prompt)}
        if extra_pnginfo is not None:
            metadata.update({key: json.dumps(extra_pnginfo[key]) for key in extra_pnginfo})

        checkpoint_path = os.path.join(target_dir, f"{base_name}_{counter:05}_.safetensors")
        comfy.sd.save_checkpoint(checkpoint_path, model, clip, vae, metadata=metadata)
        return {}
|
|
|
|
|
|
|
|
|
2023-06-20 23:17:03 +00:00
|
|
|
# Registry mapping each node's public type name to its implementing class.
NODE_CLASS_MAPPINGS = dict(
    ModelMergeSimple=ModelMergeSimple,
    ModelMergeBlocks=ModelMergeBlocks,
    CheckpointSave=CheckpointSave,
)
|