diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py
index 1eab54d4..9688cbd5 100644
--- a/comfy/diffusers_convert.py
+++ b/comfy/diffusers_convert.py
@@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 
 code2idx = {"q": 0, "k": 1, "v": 2}
 
-def convert_text_enc_state_dict_v20(text_enc_dict):
+def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
     capture_qkv_weight = {}
     capture_qkv_bias = {}
     for k, v in text_enc_dict.items():
+        if not k.startswith(prefix):
+            continue
         if (
             k.endswith(".self_attn.q_proj.weight")
             or k.endswith(".self_attn.k_proj.weight")
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 923c4348..e4c9391d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
+from . import utils
 
 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, v_prediction=False):
@@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):
 
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
+        self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
         self.v_prediction = v_prediction
@@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)
 
+    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
+        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
+        unet_state_dict = self.diffusion_model.state_dict()
+        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
+        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
+        if self.get_dtype() == torch.float16:
+            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
+            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
+        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
+
 
 class SD21UNCLIP(BaseModel):
     def __init__(self, model_config, noise_aug_config, v_prediction=True):
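Note: the new prefix argument lets convert_text_enc_state_dict_v20 pick a single text encoder out of a combined state dict (SDXL carries both clip_g and clip_l). The conversion itself re-fuses the transformers-style split q/k/v projections back into the single in_proj tensor that open_clip checkpoints store. A minimal sketch of that fusion, with illustrative names and shapes:

import torch

# Split q/k/v projection weights are concatenated along dim 0 into one
# in_proj tensor, i.e. "...self_attn.{q,k,v}_proj.weight" keys feed
# "...attn.in_proj_weight" (a sketch, not the exact source).
def fuse_qkv(q, k, v):
    return torch.cat([q, k, v], dim=0)

q = k = v = torch.randn(8, 8)
assert fuse_qkv(q, k, v).shape == (24, 8)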
diff --git a/comfy/sd.py b/comfy/sd.py
index dbfbdbe3..21d7b8a5 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -545,11 +545,11 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
-            self.patcher.patch_model()
+            self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
-            self.patcher.unpatch_model()
+            self.unpatch_model()
         except Exception as e:
-            self.patcher.unpatch_model()
+            self.unpatch_model()
             raise e
 
         cond_out = cond
@@ -564,6 +564,15 @@ class CLIP:
     def load_sd(self, sd):
         return self.cond_stage_model.load_sd(sd)
 
+    def get_sd(self):
+        return self.cond_stage_model.state_dict()
+
+    def patch_model(self):
+        self.patcher.patch_model()
+
+    def unpatch_model(self):
+        self.patcher.unpatch_model()
+
 class VAE:
     def __init__(self, ckpt_path=None, device=None, config=None):
         if config is None:
@@ -665,6 +674,10 @@ class VAE:
         self.first_stage_model = self.first_stage_model.cpu()
         return samples
 
+    def get_sd(self):
+        return self.first_stage_model.state_dict()
+
+
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
     #print(current_batch_size, target_batch_size)
@@ -1135,3 +1148,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         print("left over keys:", left_over)
 
     return (ModelPatcher(model), clip, vae, clipvision)
+
+def save_checkpoint(output_path, model, clip, vae, metadata=None):
+    try:
+        model.patch_model()
+        clip.patch_model()
+        sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
+        utils.save_torch_file(sd, output_path, metadata=metadata)
+        model.unpatch_model()
+        clip.unpatch_model()
+    except Exception as e:
+        model.unpatch_model()
+        clip.unpatch_model()
+        raise e
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 51da9456..6b17b089 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -9,6 +9,8 @@ from . import sdxl_clip
 from . import supported_models_base
 from . import latent_formats
 
+from . import diffusers_convert
+
 class SD15(supported_models_base.BASE):
     unet_config = {
         "context_dim": 768,
@@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
         state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
         return state_dict
 
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        replace_prefix[""] = "cond_stage_model.model."
+        state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
+        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
+        return state_dict
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
 
@@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE):
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
 
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
 
@@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE):
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
 
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        keys_to_replace = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        for k in state_dict:
+            if k.startswith("clip_l"):
+                state_dict_g[k] = state_dict[k]
+
+        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
+        replace_prefix["clip_l"] = "conditioner.embedders.0"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
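Note: every process_*_state_dict_for_saving hook above, and the BASE defaults that follow, reduces to a key-prefix rewrite. A standalone sketch of what state_dict_prefix_replace is assumed to do (illustrative; the real helper lives in comfy/supported_models_base.py and may differ in detail):

def state_dict_prefix_replace(state_dict, replace_prefix):
    # Rewrite "old_prefix.rest" keys to "new_prefix.rest"; an empty old
    # prefix prepends the new prefix to every key.
    for old, new in replace_prefix.items():
        for k in [x for x in state_dict if x.startswith(old)]:
            state_dict[new + k[len(old):]] = state_dict.pop(k)
    return state_dict

# Example: {"": "model.diffusion_model."} maps a hypothetical key
# "input_blocks.0.0.weight" to "model.diffusion_model.input_blocks.0.0.weight".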
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 3312a99d..0b0235ca 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -64,3 +64,15 @@ class BASE:
 
     def process_clip_state_dict(self, state_dict):
         return state_dict
+
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "cond_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_unet_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "model.diffusion_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_vae_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "first_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
diff --git a/comfy/utils.py b/comfy/utils.py
index 7a7f1fa1..b6434905 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -2,10 +2,10 @@ import torch
 import math
 import struct
 import comfy.checkpoint_pickle
+import safetensors.torch
 
 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
-        import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
         if safe_load:
@@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
             sd = pl_sd
     return sd
 
+def save_torch_file(sd, ckpt, metadata=None):
+    if metadata is not None:
+        safetensors.torch.save_file(sd, ckpt, metadata=metadata)
+    else:
+        safetensors.torch.save_file(sd, ckpt)
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
             sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd
 
+def convert_sd_to(state_dict, dtype):
+    keys = list(state_dict.keys())
+    for k in keys:
+        state_dict[k] = state_dict[k].to(dtype)
+    return state_dict
+
 def safetensors_header(safetensors_path, max_size=100*1024*1024):
     with open(safetensors_path, "rb") as f:
         header = f.read(8)
diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py
index 52b73f70..4f71b203 100644
--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py
@@ -1,4 +1,8 @@
-
+import comfy.sd
+import comfy.utils
+import folder_paths
+import json
+import os
 
 class ModelMergeSimple:
     @classmethod
@@ -49,7 +53,43 @@ class ModelMergeBlocks:
                 m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
         return (m, )
 
+class CheckpointSave:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP",),
+                              "vae": ("VAE",),
+                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+    OUTPUT_NODE = True
+
+    CATEGORY = "_for_testing/model_merging"
+
+    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        prompt_info = ""
+        if prompt is not None:
+            prompt_info = json.dumps(prompt)
+
+        metadata = {"prompt": prompt_info}
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
+
+        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+
+        comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
+        return {}
+
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSimple": ModelMergeSimple,
-    "ModelMergeBlocks": ModelMergeBlocks
+    "ModelMergeBlocks": ModelMergeBlocks,
+    "CheckpointSave": CheckpointSave,
 }
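Note: a hypothetical round trip through the new save path from Python (file paths made up; the CheckpointSave node above does the same thing from the UI):

import comfy.sd

# load_checkpoint_guess_config returns (ModelPatcher, clip, vae, clipvision);
# save_checkpoint patches the model, collects the processed unet/clip/vae
# state dicts via state_dict_for_saving, and writes one .safetensors file.
model, clip, vae, _ = comfy.sd.load_checkpoint_guess_config("models/checkpoints/example.safetensors")
comfy.sd.save_checkpoint("output/checkpoints/example_resave.safetensors", model, clip, vae,
                         metadata={"note": "round trip"})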
output["latent_format_version_0"] = torch.tensor([]) - safetensors.torch.save_file(output, file, metadata=metadata) - + comfy.utils.save_torch_file(output, file, metadata=metadata) return {} diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c5a209ee..61c277bf 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -144,6 +144,7 @@ "\n", "\n", "# ESRGAN upscale model\n", + "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", "\n", diff --git a/web/scripts/app.js b/web/scripts/app.js index 4e83c40a..09310c7f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1468,7 +1468,7 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); - } else if (file.name?.endsWith(".latent")) { + } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) { const info = await getLatentMetadata(file); if (info.workflow) { this.loadGraphData(JSON.parse(info.workflow)); diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 977b5ac2..c5293dfa 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -55,11 +55,12 @@ export function getLatentMetadata(file) { const dataView = new DataView(safetensorsData.buffer); let header_size = dataView.getUint32(0, true); let offset = 8; - let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size))); r(header.__metadata__); }; - reader.readAsArrayBuffer(file); + var slice = file.slice(0, 1024 * 1024 * 4); + reader.readAsArrayBuffer(slice); }); } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 99e9123a..12fda127 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -545,7 +545,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png,.latent", + accept: ".json,image/png,.latent,.safetensors", style: {display: "none"}, parent: document.body, onchange: () => {