load_unet -> load_diffusion_model with a model_options argument.

This commit is contained in:
comfyanonymous 2024-08-12 23:18:54 -04:00
parent 5942c17d55
commit a562c17e8a
3 changed files with 19 additions and 8 deletions

View File

@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
if text_encoder2_path is not None: if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path) text_encoder_paths.append(text_encoder2_path)
unet = comfy.sd.load_unet(unet_path) unet = comfy.sd.load_diffusion_model(unet_path)
clip = None clip = None
if output_clip: if output_clip:

View File

@ -590,7 +590,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
dtype = model_options.get("dtype", None)
#Allow loading unets from checkpoint files #Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@ -632,6 +633,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model = model.to(offload_device) model = model.to(offload_device)
model.load_model_weights(new_sd, "") model.load_model_weights(new_sd, "")
@ -640,14 +642,23 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
logging.info("left over keys in unet: {}".format(left_over)) logging.info("left over keys in unet: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path, dtype=None):
def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path) sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd, dtype=dtype) model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None: if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model return model
def load_unet(unet_path, dtype=None):
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
def load_unet_state_dict(sd, dtype=None):
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}): def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None clip_sd = None
load_models = [model] load_models = [model]

View File

@ -826,14 +826,14 @@ class UNETLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype): def load_unet(self, unet_name, weight_dtype):
dtype = None model_options = {}
if weight_dtype == "fp8_e4m3fn": if weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2": elif weight_dtype == "fp8_e5m2":
dtype = torch.float8_e5m2 model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path("unet", unet_name) unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path, dtype=dtype) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,) return (model,)
class CLIPLoader: class CLIPLoader: