New CheckpointLoaderSimple to load checkpoints without a config.
This commit is contained in:
parent
c1f5855ac1
commit
94bb0375b0
|
@ -81,7 +81,7 @@ class DDPM(torch.nn.Module):
|
|||
super().__init__()
|
||||
assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
|
||||
self.parameterization = parameterization
|
||||
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
|
||||
# print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
|
||||
self.cond_stage_model = None
|
||||
self.clip_denoised = clip_denoised
|
||||
self.log_every_t = log_every_t
|
||||
|
@ -522,8 +522,8 @@ class LatentDiffusion(DDPM):
|
|||
"""main class"""
|
||||
|
||||
def __init__(self,
|
||||
first_stage_config,
|
||||
cond_stage_config,
|
||||
first_stage_config={},
|
||||
cond_stage_config={},
|
||||
num_timesteps_cond=None,
|
||||
cond_stage_key="image",
|
||||
cond_stage_trainable=False,
|
||||
|
@ -562,8 +562,6 @@ class LatentDiffusion(DDPM):
|
|||
|
||||
# self.instantiate_first_stage(first_stage_config)
|
||||
# self.instantiate_cond_stage(cond_stage_config)
|
||||
self.first_stage_config = first_stage_config
|
||||
self.cond_stage_config = cond_stage_config
|
||||
|
||||
self.cond_stage_forward = cond_stage_forward
|
||||
self.clip_denoised = False
|
||||
|
|
104
comfy/sd.py
104
comfy/sd.py
|
@ -317,9 +317,7 @@ class VAE:
|
|||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples):
|
||||
tile_x = tile_y = 64
|
||||
overlap = 8
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 8):
|
||||
model_management.unload_model()
|
||||
output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
|
@ -656,3 +654,103 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
|
|||
sd = load_torch_file(ckpt_path)
|
||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||
return (ModelPatcher(model), clip, vae)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, fp16=False, output_vae=True, output_clip=True, embedding_directory=None):
    """Load a checkpoint without a config file by inferring the config from its weights.

    The UNet/model configuration is derived from the shapes of a few tensors
    in the checkpoint's state dict, and the CLIP text-encoder class is picked
    based on which keys are present.

    Args:
        ckpt_path: Path of the checkpoint file, passed to ``load_torch_file``.
        fp16: Stored as ``use_fp16`` in the generated UNet config.
        output_vae: When true, a ``VAE`` is created and its first-stage model
            receives weights from the checkpoint.
        output_clip: When true, a ``CLIP`` is created and its cond-stage model
            receives weights from the checkpoint.
        embedding_directory: Forwarded to the ``CLIP`` constructor.

    Returns:
        A tuple ``(ModelPatcher(model), clip, vae)``; ``clip`` and ``vae`` are
        ``None`` when the corresponding ``output_*`` flag is false.

    Raises:
        KeyError: If the state dict lacks one of the UNet tensors that the
            configuration is inferred from.
    """
    sd = load_torch_file(ckpt_path)
    sd_keys = sd.keys()
    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    # Container module: the VAE / CLIP sub-models are attached to it under the
    # attribute names used in the checkpoint so they can be filled in the same
    # pass as the main model.
    w = WeightsLoader()
    if output_vae:
        vae = VAE()
        w.first_stage_model = vae.first_stage_model

    if output_clip:
        clip_config = {}
        if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
        else:
            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model

    # Only hand the container over when at least one sub-model was attached.
    load_state_dict_to = [w] if (output_vae or output_clip) else []

    sd_config = {
        "linear_start": 0.00085,
        "linear_end": 0.012,
        "num_timesteps_cond": 1,
        "log_every_t": 200,
        "timesteps": 1000,
        "first_stage_key": "jpg",
        "cond_stage_key": "txt",
        "image_size": 64,
        "channels": 4,
        "cond_stage_trainable": False,
        "monitor": "val/loss_simple_ema",
        "scale_factor": 0.18215,
        "use_ema": False,
    }

    unet_config = {
        "use_checkpoint": True,
        "image_size": 32,
        "out_channels": 4,
        "attention_resolutions": [
            4,
            2,
            1
        ],
        "num_res_blocks": 2,
        "channel_mult": [
            1,
            2,
            4,
            4
        ],
        "use_spatial_transformer": True,
        "transformer_depth": 1,
        "legacy": False
    }

    # A 2-D proj_in weight means the transformer input projection is linear.
    if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
        unet_config['use_linear_in_transformer'] = True

    unet_config["use_fp16"] = fp16
    unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
    unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
    unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]

    sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
    model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}

    if unet_config["in_channels"] > 4: #inpainting model
        sd_config["conditioning_key"] = "hybrid"
        sd_config["finetune_keys"] = None
        model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
    else:
        sd_config["conditioning_key"] = "crossattn"

    if unet_config["context_dim"] == 1024:
        unet_config["num_head_channels"] = 64 #SD2.x
    else:
        unet_config["num_heads"] = 8 #SD1.x

    model = instantiate_from_config(model_config)
    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)

    if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
        # Use a private generator so that probing the model does not reseed
        # the process-wide RNG as a side effect of loading a checkpoint.
        probe_rng = torch.Generator(device="cpu").manual_seed(1)
        cond = torch.zeros((1, 2, unet_config["context_dim"]), device="cpu")
        x_in = torch.rand((1, unet_config["in_channels"], 8, 8), device="cpu", generator=probe_rng)
        # The probe is inference only; no autograd graph is needed.
        with torch.no_grad():
            out = model.apply_model(x_in, torch.tensor([999], device="cpu"), cond)
        if out.mean() < -0.6: #mean of eps should be ~0 and mean of v prediction should be ~-1
            model.parameterization = 'v'

    return (ModelPatcher(model), clip, vae)
|
||||
|
|
23
nodes.py
23
nodes.py
|
@ -202,6 +202,28 @@ class CheckpointLoader:
|
|||
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
|
||||
|
||||
class CheckpointLoaderSimple:
    """Node that loads a checkpoint without requiring a separate config file.

    The checkpoint is looked up in ``models/checkpoints`` next to this file and
    loaded through ``comfy.sd.load_checkpoint_guess_config``.
    """
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    ckpt_dir = os.path.join(models_dir, "checkpoints")

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
                              "type": (["fp16", "fp32"],),
                              "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
                              }}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "_for_testing"

    def load_checkpoint(self, ckpt_name, type, stop_at_clip_layer, output_vae=True, output_clip=True):
        """Load ``ckpt_name`` from ``ckpt_dir`` and return the loader's result.

        Args:
            ckpt_name: Checkpoint file name, joined onto ``ckpt_dir``.
            type: ``"fp16"`` requests fp16; any other value passes ``False``
                for the loader's fp16 argument. The name shadows the builtin
                but must match the ``"type"`` key in ``INPUT_TYPES``.
            stop_at_clip_layer: Passed to ``clip_layer`` on the returned CLIP.
            output_vae: Forwarded to the loader.
            output_clip: Forwarded to the loader.

        Returns:
            Whatever ``comfy.sd.load_checkpoint_guess_config`` returns
            (declared by ``RETURN_TYPES`` as model, CLIP, VAE).
        """
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
        # Forward the caller's flags instead of hard-coding True; the defaults
        # keep the previous behaviour for existing callers.
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, type=="fp16", output_vae=output_vae, output_clip=output_clip, embedding_directory=CheckpointLoader.embedding_directory)
        # The CLIP slot is None when no CLIP was produced.
        if out[1] is not None:
            out[1].clip_layer(stop_at_clip_layer)
        return out
|
||||
|
||||
class LoraLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
lora_dir = os.path.join(models_dir, "loras")
|
||||
|
@ -837,6 +859,7 @@ NODE_CLASS_MAPPINGS = {
|
|||
"DiffControlNetLoader": DiffControlNetLoader,
|
||||
"T2IAdapterLoader": T2IAdapterLoader,
|
||||
"VAEDecodeTiled": VAEDecodeTiled,
|
||||
"CheckpointLoaderSimple": CheckpointLoaderSimple,
|
||||
}
|
||||
|
||||
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
|
||||
|
|
Loading…
Reference in New Issue