Make the load checkpoint with config function call the regular one.
I was going to remove this function entirely because it is unmaintainable, but I think this is the best compromise. The clip skip and v_prediction parts of the configs should still work, but the fp16 vs fp32 selection will not.
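For callers still on the deprecated entry point, a minimal migration sketch (the checkpoint and YAML paths are placeholders, not part of this commit):

import comfy.sd

# Deprecated: after this commit it logs a warning, wraps
# load_checkpoint_guess_config, and only applies the clip skip /
# v_prediction parts of the YAML config.
#model, clip, vae = comfy.sd.load_checkpoint(config_path="v1-inference.yaml", ckpt_path="model.ckpt")

# Preferred: let the model configuration be guessed from the checkpoint itself.
model, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
    "model.ckpt",              # placeholder checkpoint path
    output_vae=True,
    output_clip=True,
    embedding_directory=None,  # optional folder of textual inversion embeddings
)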
parent 3787b4f246
commit c61eadf69a
comfy/sd.py (81 lines changed)
@@ -418,6 +418,8 @@ def load_gligen(ckpt_path):
     return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
 
 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+    logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
+    model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
     #TODO: this function is a mess and should be removed eventually
     if config is None:
         with open(config_path, 'r') as stream:
@@ -425,81 +427,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     model_config_params = config['model']['params']
     clip_config = model_config_params['cond_stage_config']
     scale_factor = model_config_params['scale_factor']
-    vae_config = model_config_params['first_stage_config']
-
-    fp16 = False
-    if "unet_config" in model_config_params:
-        if "params" in model_config_params["unet_config"]:
-            unet_config = model_config_params["unet_config"]["params"]
-            if "use_fp16" in unet_config:
-                fp16 = unet_config.pop("use_fp16")
-                if fp16:
-                    unet_config["dtype"] = torch.float16
-
-    noise_aug_config = None
-    if "noise_aug_config" in model_config_params:
-        noise_aug_config = model_config_params["noise_aug_config"]
-
-    model_type = model_base.ModelType.EPS
 
     if "parameterization" in model_config_params:
         if model_config_params["parameterization"] == "v":
-            model_type = model_base.ModelType.V_PREDICTION
+            m = model.clone()
+            class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION):
+                pass
+            m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
+            model = m
 
-    clip = None
-    vae = None
+    layer_idx = clip_config.get("params", {}).get("layer_idx", None)
+    if layer_idx is not None:
+        clip.clip_layer(layer_idx)
 
-    class WeightsLoader(torch.nn.Module):
-        pass
-
-    if state_dict is None:
-        state_dict = comfy.utils.load_torch_file(ckpt_path)
-
-    class EmptyClass:
-        pass
-
-    model_config = comfy.supported_models_base.BASE({})
-
-    from . import latent_formats
-    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
-    model_config.unet_config = model_detection.convert_config(unet_config)
-
-    if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
-        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
-    else:
-        model = model_base.BaseModel(model_config, model_type=model_type)
-
-    if config['model']["target"].endswith("LatentInpaintDiffusion"):
-        model.set_inpaint()
-
-    if fp16:
-        model = model.half()
-
-    offload_device = model_management.unet_offload_device()
-    model = model.to(offload_device)
-    model.load_model_weights(state_dict, "model.diffusion_model.")
-
-    if output_vae:
-        vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
-        vae = VAE(sd=vae_sd, config=vae_config)
-
-    if output_clip:
-        w = WeightsLoader()
-        clip_target = EmptyClass()
-        clip_target.params = clip_config.get("params", {})
-        if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
-            clip_target.clip = sd2_clip.SD2ClipModel
-            clip_target.tokenizer = sd2_clip.SD2Tokenizer
-            clip = CLIP(clip_target, embedding_directory=embedding_directory)
-            w.cond_stage_model = clip.cond_stage_model.clip_h
-        elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
-            clip_target.clip = sd1_clip.SD1ClipModel
-            clip_target.tokenizer = sd1_clip.SD1Tokenizer
-            clip = CLIP(clip_target, embedding_directory=embedding_directory)
-            w.cond_stage_model = clip.cond_stage_model.clip_l
-        load_clip_weights(w, state_dict)
-
-    return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
+    return (model, clip, vae)
 
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
     sd = comfy.utils.load_torch_file(ckpt_path)
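As a rough sketch of what the deprecated path still honors, with field names taken from the surviving code above (the checkpoint path and the concrete values are illustrative only):

import comfy.sd

# Only these config fields still have an effect: 'scale_factor' and
# 'cond_stage_config' are still read (so they must be present), a
# 'parameterization' of "v" patches model_sampling to v-prediction,
# and 'layer_idx' applies clip skip. fp16/fp32 settings in the config
# are now ignored; the dtype comes from load_checkpoint_guess_config.
config = {
    "model": {
        "params": {
            "scale_factor": 0.18215,           # read but no longer used
            "parameterization": "v",           # -> ModelSamplingAdvanced patch
            "cond_stage_config": {
                "params": {"layer_idx": -2},   # -> clip.clip_layer(-2)
            },
        },
    },
}
model, clip, vae = comfy.sd.load_checkpoint(ckpt_path="model.ckpt", config=config)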