diff --git a/comfy/sd.py b/comfy/sd.py index 2be8edef..41ce18c8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -567,7 +567,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o return (model_patcher, clip, vae, clipvision) -def load_unet_state_dict(sd): #load unet in diffusers or regular format +def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format #Allow loading unets from checkpoint files diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) @@ -576,7 +576,6 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format sd = temp_sd parameters = comfy.utils.calculate_parameters(sd) - unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() model_config = model_detection.model_config_from_unet(sd, "") @@ -603,7 +602,11 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + if dtype is None: + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + else: + unet_dtype = dtype + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model = model_config.get_model(new_sd, "") @@ -614,9 +617,9 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format logging.info("left over keys in unet: {}".format(left_over)) return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) -def load_unet(unet_path): +def load_unet(unet_path, dtype=None): sd = comfy.utils.load_torch_file(unet_path) - model = load_unet_state_dict(sd) + model = load_unet_state_dict(sd, dtype=dtype) if model is None: logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) diff --git a/nodes.py b/nodes.py index 93d24ae5..fbd0c6ce 100644 --- a/nodes.py +++ b/nodes.py @@ -818,15 +818,17 @@ class UNETLoader: @classmethod def INPUT_TYPES(s): return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ), + "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],) }} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" CATEGORY = "advanced/loaders" - def load_unet(self, unet_name): + def load_unet(self, unet_name, weight_dtype): + weight_dtype = {"default":None, "fp8_e4m3fn":torch.float8_e4m3fn, "fp8_e5m2":torch.float8_e4m3fn}[weight_dtype] unet_path = folder_paths.get_full_path("unet", unet_name) - model = comfy.sd.load_unet(unet_path) + model = comfy.sd.load_unet(unet_path, dtype=weight_dtype) return (model,) class CLIPLoader: