Add CheckpointSave node to save checkpoints.
The created checkpoints contain workflow metadata that can be loaded by dragging them on top of the UI or loading them with the "Load" button. Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI is using for inference on your hardware. To force fp32 use: --force-fp32 Anything that patches the model weights like merging or loras will be saved. The output directory is currently set to: output/checkpoints but that might change in the future.
This commit is contained in:
parent
b72a7a835a
commit
9b93b920be
|
@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict):
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
new_state_dict = {}
|
||||
capture_qkv_weight = {}
|
||||
capture_qkv_bias = {}
|
||||
for k, v in text_enc_dict.items():
|
||||
if not k.startswith(prefix):
|
||||
continue
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.weight")
|
||||
or k.endswith(".self_attn.k_proj.weight")
|
||||
|
|
|
@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
|||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import numpy as np
|
||||
from . import utils
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, v_prediction=False):
|
||||
|
@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):
|
|||
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
self.diffusion_model = UNetModel(**unet_config)
|
||||
self.v_prediction = v_prediction
|
||||
|
@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module):
|
|||
def process_latent_out(self, latent):
|
||||
return self.latent_format.process_out(latent)
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
||||
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
||||
if self.get_dtype() == torch.float16:
|
||||
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
|
||||
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
|
||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
||||
|
||||
|
||||
class SD21UNCLIP(BaseModel):
|
||||
def __init__(self, model_config, noise_aug_config, v_prediction=True):
|
||||
|
|
32
comfy/sd.py
32
comfy/sd.py
|
@ -545,11 +545,11 @@ class CLIP:
|
|||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.clip_layer(self.layer_idx)
|
||||
try:
|
||||
self.patcher.patch_model()
|
||||
self.patch_model()
|
||||
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
||||
self.patcher.unpatch_model()
|
||||
self.unpatch_model()
|
||||
except Exception as e:
|
||||
self.patcher.unpatch_model()
|
||||
self.unpatch_model()
|
||||
raise e
|
||||
|
||||
cond_out = cond
|
||||
|
@ -564,6 +564,15 @@ class CLIP:
|
|||
def load_sd(self, sd):
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
return self.cond_stage_model.state_dict()
|
||||
|
||||
def patch_model(self):
|
||||
self.patcher.patch_model()
|
||||
|
||||
def unpatch_model(self):
|
||||
self.patcher.unpatch_model()
|
||||
|
||||
class VAE:
|
||||
def __init__(self, ckpt_path=None, device=None, config=None):
|
||||
if config is None:
|
||||
|
@ -665,6 +674,10 @@ class VAE:
|
|||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
#print(current_batch_size, target_batch_size)
|
||||
|
@ -1135,3 +1148,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||
print("left over keys:", left_over)
|
||||
|
||||
return (ModelPatcher(model), clip, vae, clipvision)
|
||||
|
||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||
try:
|
||||
model.patch_model()
|
||||
clip.patch_model()
|
||||
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
||||
utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||
model.unpatch_model()
|
||||
clip.unpatch_model()
|
||||
except Exception as e:
|
||||
model.unpatch_model()
|
||||
clip.unpatch_model()
|
||||
raise e
|
||||
|
|
|
@ -9,6 +9,8 @@ from . import sdxl_clip
|
|||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
|
||||
from . import diffusers_convert
|
||||
|
||||
class SD15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
|
@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
|
|||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix[""] = "cond_stage_model.model."
|
||||
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
return state_dict
|
||||
|
||||
def clip_target(self):
|
||||
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
||||
|
||||
|
@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE):
|
|||
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
|
||||
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
|
||||
def clip_target(self):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
||||
|
||||
|
@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE):
|
|||
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
keys_to_replace = {}
|
||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||
for k in state_dict:
|
||||
if k.startswith("clip_l"):
|
||||
state_dict_g[k] = state_dict[k]
|
||||
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
|
||||
replace_prefix["clip_l"] = "conditioner.embedders.0"
|
||||
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
|
||||
def clip_target(self):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
||||
|
||||
|
|
|
@ -64,3 +64,15 @@ class BASE:
|
|||
def process_clip_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
return state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model.diffusion_model."}
|
||||
return state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_vae_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "first_stage_model."}
|
||||
return state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
|
|
|
@ -2,10 +2,10 @@ import torch
|
|||
import math
|
||||
import struct
|
||||
import comfy.checkpoint_pickle
|
||||
import safetensors.torch
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
||||
else:
|
||||
if safe_load:
|
||||
|
@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
|
|||
sd = pl_sd
|
||||
return sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
if metadata is not None:
|
||||
safetensors.torch.save_file(sd, ckpt, metadata=metadata)
|
||||
else:
|
||||
safetensors.torch.save_file(sd, ckpt)
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}positional_embedding": "{}embeddings.position_embedding.weight",
|
||||
|
@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
|||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return sd
|
||||
|
||||
def convert_sd_to(state_dict, dtype):
|
||||
keys = list(state_dict.keys())
|
||||
for k in keys:
|
||||
state_dict[k] = state_dict[k].to(dtype)
|
||||
return state_dict
|
||||
|
||||
def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||
with open(safetensors_path, "rb") as f:
|
||||
header = f.read(8)
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import json
|
||||
import os
|
||||
|
||||
class ModelMergeSimple:
|
||||
@classmethod
|
||||
|
@ -49,7 +53,43 @@ class ModelMergeBlocks:
|
|||
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
|
||||
return (m, )
|
||||
|
||||
class CheckpointSave:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"clip": ("CLIP",),
|
||||
"vae": ("VAE",),
|
||||
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing/model_merging"
|
||||
|
||||
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
prompt_info = ""
|
||||
if prompt is not None:
|
||||
prompt_info = json.dumps(prompt)
|
||||
|
||||
metadata = {"prompt": prompt_info}
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
|
||||
return {}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSimple": ModelMergeSimple,
|
||||
"ModelMergeBlocks": ModelMergeBlocks
|
||||
"ModelMergeBlocks": ModelMergeBlocks,
|
||||
"CheckpointSave": CheckpointSave,
|
||||
}
|
||||
|
|
3
nodes.py
3
nodes.py
|
@ -286,8 +286,7 @@ class SaveLatent:
|
|||
output["latent_tensor"] = samples["samples"]
|
||||
output["latent_format_version_0"] = torch.tensor([])
|
||||
|
||||
safetensors.torch.save_file(output, file, metadata=metadata)
|
||||
|
||||
comfy.utils.save_torch_file(output, file, metadata=metadata)
|
||||
return {}
|
||||
|
||||
|
||||
|
|
|
@ -144,6 +144,7 @@
|
|||
"\n",
|
||||
"\n",
|
||||
"# ESRGAN upscale model\n",
|
||||
"#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
|
||||
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
|
||||
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
|
||||
"\n",
|
||||
|
|
|
@ -1468,7 +1468,7 @@ export class ComfyApp {
|
|||
this.loadGraphData(JSON.parse(reader.result));
|
||||
};
|
||||
reader.readAsText(file);
|
||||
} else if (file.name?.endsWith(".latent")) {
|
||||
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
|
||||
const info = await getLatentMetadata(file);
|
||||
if (info.workflow) {
|
||||
this.loadGraphData(JSON.parse(info.workflow));
|
||||
|
|
|
@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
|
|||
const dataView = new DataView(safetensorsData.buffer);
|
||||
let header_size = dataView.getUint32(0, true);
|
||||
let offset = 8;
|
||||
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
|
||||
let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
|
||||
r(header.__metadata__);
|
||||
};
|
||||
|
||||
reader.readAsArrayBuffer(file);
|
||||
var slice = file.slice(0, 1024 * 1024 * 4);
|
||||
reader.readAsArrayBuffer(slice);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -545,7 +545,7 @@ export class ComfyUI {
|
|||
const fileInput = $el("input", {
|
||||
id: "comfy-file-input",
|
||||
type: "file",
|
||||
accept: ".json,image/png,.latent",
|
||||
accept: ".json,image/png,.latent,.safetensors",
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
onchange: () => {
|
||||
|
|
Loading…
Reference in New Issue