load_unet -> load_diffusion_model with a model_options argument.
This commit is contained in:
parent
5942c17d55
commit
a562c17e8a
|
@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
|||
if text_encoder2_path is not None:
|
||||
text_encoder_paths.append(text_encoder2_path)
|
||||
|
||||
unet = comfy.sd.load_unet(unet_path)
|
||||
unet = comfy.sd.load_diffusion_model(unet_path)
|
||||
|
||||
clip = None
|
||||
if output_clip:
|
||||
|
|
17
comfy/sd.py
17
comfy/sd.py
|
@ -590,7 +590,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
|
||||
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
|
||||
dtype = model_options.get("dtype", None)
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
|
@ -632,6 +633,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
|||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
|
@ -640,14 +642,23 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
|||
logging.info("left over keys in unet: {}".format(left_over))
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd, dtype=dtype)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None):
|
||||
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
||||
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
|
|
8
nodes.py
8
nodes.py
|
@ -826,14 +826,14 @@ class UNETLoader:
|
|||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_unet(self, unet_name, weight_dtype):
|
||||
dtype = None
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
dtype = torch.float8_e4m3fn
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
dtype = torch.float8_e5m2
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
model = comfy.sd.load_unet(unet_path, dtype=dtype)
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
|
||||
class CLIPLoader:
|
||||
|
|
Loading…
Reference in New Issue