load_unet -> load_diffusion_model with a model_options argument.
This commit is contained in:
parent
5942c17d55
commit
a562c17e8a
|
@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
||||||
if text_encoder2_path is not None:
|
if text_encoder2_path is not None:
|
||||||
text_encoder_paths.append(text_encoder2_path)
|
text_encoder_paths.append(text_encoder2_path)
|
||||||
|
|
||||||
unet = comfy.sd.load_unet(unet_path)
|
unet = comfy.sd.load_diffusion_model(unet_path)
|
||||||
|
|
||||||
clip = None
|
clip = None
|
||||||
if output_clip:
|
if output_clip:
|
||||||
|
|
17
comfy/sd.py
17
comfy/sd.py
|
@ -590,7 +590,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
|
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
|
||||||
|
dtype = model_options.get("dtype", None)
|
||||||
|
|
||||||
#Allow loading unets from checkpoint files
|
#Allow loading unets from checkpoint files
|
||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
|
@ -632,6 +633,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
|
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
|
@ -640,14 +642,23 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
||||||
logging.info("left over keys in unet: {}".format(left_over))
|
logging.info("left over keys in unet: {}".format(left_over))
|
||||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def load_unet(unet_path, dtype=None):
|
|
||||||
|
def load_diffusion_model(unet_path, model_options={}):
|
||||||
sd = comfy.utils.load_torch_file(unet_path)
|
sd = comfy.utils.load_torch_file(unet_path)
|
||||||
model = load_unet_state_dict(sd, dtype=dtype)
|
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||||
if model is None:
|
if model is None:
|
||||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def load_unet(unet_path, dtype=None):
|
||||||
|
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||||
|
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||||
|
|
||||||
|
def load_unet_state_dict(sd, dtype=None):
|
||||||
|
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||||
|
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||||
clip_sd = None
|
clip_sd = None
|
||||||
load_models = [model]
|
load_models = [model]
|
||||||
|
|
8
nodes.py
8
nodes.py
|
@ -826,14 +826,14 @@ class UNETLoader:
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_unet(self, unet_name, weight_dtype):
|
def load_unet(self, unet_name, weight_dtype):
|
||||||
dtype = None
|
model_options = {}
|
||||||
if weight_dtype == "fp8_e4m3fn":
|
if weight_dtype == "fp8_e4m3fn":
|
||||||
dtype = torch.float8_e4m3fn
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
elif weight_dtype == "fp8_e5m2":
|
elif weight_dtype == "fp8_e5m2":
|
||||||
dtype = torch.float8_e5m2
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||||
model = comfy.sd.load_unet(unet_path, dtype=dtype)
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
class CLIPLoader:
|
class CLIPLoader:
|
||||||
|
|
Loading…
Reference in New Issue