diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py
index 98b888a1..56e63a75 100644
--- a/comfy/diffusers_load.py
+++ b/comfy/diffusers_load.py
@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     if text_encoder2_path is not None:
         text_encoder_paths.append(text_encoder2_path)
 
-    unet = comfy.sd.load_unet(unet_path)
+    unet = comfy.sd.load_diffusion_model(unet_path)
 
     clip = None
     if output_clip:
diff --git a/comfy/sd.py b/comfy/sd.py
index c8a2f086..13909d67 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -590,7 +590,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
 
     return (model_patcher, clip, vae, clipvision)
 
-def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+    dtype = model_options.get("dtype", None)
 
     #Allow loading unets from checkpoint files
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -632,6 +633,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
 
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
@@ -640,14 +642,23 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
         logging.info("left over keys in unet: {}".format(left_over))
     return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
 
-def load_unet(unet_path, dtype=None):
+
+def load_diffusion_model(unet_path, model_options={}):
     sd = comfy.utils.load_torch_file(unet_path)
-    model = load_unet_state_dict(sd, dtype=dtype)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options)
     if model is None:
         logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model
 
+def load_unet(unet_path, dtype=None):
+    print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
+    return load_diffusion_model(unet_path, model_options={"dtype": dtype})
+
+def load_unet_state_dict(sd, dtype=None):
+    print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
+    return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
+
 def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
     clip_sd = None
     load_models = [model]
diff --git a/nodes.py b/nodes.py
index e296597c..525b28d8 100644
--- a/nodes.py
+++ b/nodes.py
@@ -826,14 +826,14 @@ class UNETLoader:
     CATEGORY = "advanced/loaders"
 
     def load_unet(self, unet_name, weight_dtype):
-        dtype = None
+        model_options = {}
         if weight_dtype == "fp8_e4m3fn":
-            dtype = torch.float8_e4m3fn
+            model_options["dtype"] = torch.float8_e4m3fn
         elif weight_dtype == "fp8_e5m2":
-            dtype = torch.float8_e5m2
+            model_options["dtype"] = torch.float8_e5m2
 
         unet_path = folder_paths.get_full_path("unet", unet_name)
-        model = comfy.sd.load_unet(unet_path, dtype=dtype)
+        model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
 
 class CLIPLoader:
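
For reference, a minimal migration sketch for callers of the renamed API. Everything here follows directly from the diff (load_unet/load_diffusion_model signatures, the "dtype" and "custom_operations" keys in model_options); the checkpoint path is a hypothetical placeholder.

    import torch
    import comfy.sd

    # Hypothetical path; substitute any UNET/diffusion model checkpoint.
    unet_path = "models/unet/model.safetensors"

    # Old call: still works, but prints a deprecation warning and
    # forwards to load_diffusion_model under the hood.
    model = comfy.sd.load_unet(unet_path, dtype=torch.float8_e4m3fn)

    # New call: options travel in a single model_options dict.
    model = comfy.sd.load_diffusion_model(
        unet_path,
        model_options={
            "dtype": torch.float8_e4m3fn,    # replaces the old dtype= argument
            # "custom_operations": my_ops,   # optional; assigned to
            #                                # model_config.custom_operations
        },
    )

Packing loader options into a dict keeps the function signatures stable as new knobs (like custom_operations) are added, which is presumably why the rename and the model_options change land in the same commit.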