Support multiple paths for embeddings.
This commit is contained in:
parent
51d6427ddf
commit
50099bcd96
|
@ -168,19 +168,28 @@ def unescape_important(text):
|
|||
return text
|
||||
|
||||
def load_embed(embedding_name, embedding_directory):
|
||||
embed_path = os.path.join(embedding_directory, embedding_name)
|
||||
if not os.path.isfile(embed_path):
|
||||
extensions = ['.safetensors', '.pt', '.bin']
|
||||
valid_file = None
|
||||
for x in extensions:
|
||||
t = embed_path + x
|
||||
if os.path.isfile(t):
|
||||
valid_file = t
|
||||
break
|
||||
if valid_file is None:
|
||||
return None
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
|
||||
valid_file = None
|
||||
for embed_dir in embedding_directory:
|
||||
embed_path = os.path.join(embed_dir, embedding_name)
|
||||
if not os.path.isfile(embed_path):
|
||||
extensions = ['.safetensors', '.pt', '.bin']
|
||||
for x in extensions:
|
||||
t = embed_path + x
|
||||
if os.path.isfile(t):
|
||||
valid_file = t
|
||||
break
|
||||
else:
|
||||
embed_path = valid_file
|
||||
valid_file = embed_path
|
||||
if valid_file is not None:
|
||||
break
|
||||
|
||||
if valid_file is None:
|
||||
return None
|
||||
|
||||
embed_path = valid_file
|
||||
|
||||
if embed_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
|
|
|
@ -22,7 +22,7 @@ folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt
|
|||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||
# folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
|
||||
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
|
||||
|
@ -33,6 +33,8 @@ def add_model_folder_path(folder_name, full_folder_path):
|
|||
if folder_name in folder_names_and_paths:
|
||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||
|
||||
def get_folder_paths(folder_name):
|
||||
return folder_names_and_paths[folder_name][0][:]
|
||||
|
||||
def recursive_search(directory):
|
||||
result = []
|
||||
|
|
7
nodes.py
7
nodes.py
|
@ -188,9 +188,6 @@ class VAEEncodeForInpaint:
|
|||
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
|
||||
|
||||
class CheckpointLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
embedding_directory = os.path.join(models_dir, "embeddings")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
|
||||
|
@ -203,7 +200,7 @@ class CheckpointLoader:
|
|||
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
||||
config_path = folder_paths.get_full_path("configs", config_name)
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
class CheckpointLoaderSimple:
|
||||
@classmethod
|
||||
|
@ -217,7 +214,7 @@ class CheckpointLoaderSimple:
|
|||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return out
|
||||
|
||||
class CLIPSetLastLayer:
|
||||
|
|
Loading…
Reference in New Issue