add simple error check to model loading (#4950)

This commit is contained in:
Alex "mcmonkey" Goodwin 2024-09-17 16:57:17 +09:00 committed by GitHub
parent 0b7dfa986d
commit 254838f23c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 31 additions and 23 deletions

View File

@ -107,7 +107,7 @@ class HypernetworkLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength): def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone() model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength) patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None: if patch is not None:

View File

@ -126,7 +126,7 @@ class PhotoMakerLoader:
CATEGORY = "_for_testing/photomaker" CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name): def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder() photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data: if "id_encoder" in data:

View File

@ -15,9 +15,9 @@ class TripleCLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3): def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3) clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,) return (clip,)

View File

@ -25,7 +25,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_model(self, model_name): def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name) model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True) sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})

View File

@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
CATEGORY = "loaders/video_models" CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2]) return (out[0], out[3], out[2])

View File

@ -235,6 +235,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return None return None
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name) folder_name = map_legacy(folder_name)
global folder_names_and_paths global folder_names_and_paths

View File

@ -515,7 +515,7 @@ class CheckpointLoader:
def load_checkpoint(self, config_name, ckpt_name): def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name) config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple: class CheckpointLoaderSimple:
@ -536,7 +536,7 @@ class CheckpointLoaderSimple:
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
def load_checkpoint(self, ckpt_name): def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3] return out[:3]
@ -578,7 +578,7 @@ class unCLIPCheckpointLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out return out
@ -625,7 +625,7 @@ class LoraLoader:
if strength_model == 0 and strength_clip == 0: if strength_model == 0 and strength_clip == 0:
return (model, clip) return (model, clip)
lora_path = folder_paths.get_full_path("loras", lora_name) lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None lora = None
if self.loaded_lora is not None: if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path: if self.loaded_lora[0] == lora_path:
@ -704,11 +704,11 @@ class VAELoader:
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder)) enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
for k in enc: for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k] sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder)) dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
for k in dec: for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k] sd["taesd_decoder.{}".format(k)] = dec[k]
@ -739,7 +739,7 @@ class VAELoader:
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name) sd = self.load_taesd(vae_name)
else: else:
vae_path = folder_paths.get_full_path("vae", vae_name) vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd) vae = comfy.sd.VAE(sd=sd)
return (vae,) return (vae,)
@ -755,7 +755,7 @@ class ControlNetLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_controlnet(self, control_net_name): def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path) controlnet = comfy.controlnet.load_controlnet(controlnet_path)
return (controlnet,) return (controlnet,)
@ -771,7 +771,7 @@ class DiffControlNetLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_controlnet(self, model, control_net_name): def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
return (controlnet,) return (controlnet,)
@ -871,7 +871,7 @@ class UNETLoader:
elif weight_dtype == "fp8_e5m2": elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2 model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path("diffusion_models", unet_name) unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,) return (model,)
@ -896,7 +896,7 @@ class CLIPLoader:
else: else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path("clip", clip_name) clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,) return (clip,)
@ -913,8 +913,8 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, type): def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
if type == "sdxl": if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3": elif type == "sd3":
@ -936,7 +936,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_clip(self, clip_name): def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name) clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path) clip_vision = comfy.clip_vision.load(clip_path)
return (clip_vision,) return (clip_vision,)
@ -966,7 +966,7 @@ class StyleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_style_model(self, style_model_name): def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path("style_models", style_model_name) style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path) style_model = comfy.sd.load_style_model(style_model_path)
return (style_model,) return (style_model,)
@ -1031,7 +1031,7 @@ class GLIGENLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_gligen(self, gligen_name): def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name) gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path) gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,) return (gligen,)