add recursive_search, swap relevant os.listdirs

This commit is contained in:
BazettFraga 2023-02-09 01:22:33 +01:00
parent 3fd87cbd21
commit 81082045c2
1 changed files with 13 additions and 5 deletions

View File

@ -26,6 +26,14 @@ try:
except:
print("Could not import safetensors, safetensors support disabled.")
def recursive_search(directory):
result = []
for root, subdir, file in os.walk(directory, followlinks=True):
for filepath in file:
#we remove the first character to remove the path separator.
result.append(os.path.join(root, filepath).replace(directory,'')[1:])
return result
def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
@ -119,8 +127,8 @@ class CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (filter_files_extensions(os.listdir(s.config_dir), '.yaml'), ),
"ckpt_name": (filter_files_extensions(os.listdir(s.ckpt_dir), supported_ckpt_extensions), )}}
return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ),
"ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -138,7 +146,7 @@ class LoraLoader:
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP", ),
"lora_name": (filter_files_extensions(os.listdir(s.lora_dir), supported_pt_extensions), ),
"lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
@ -157,7 +165,7 @@ class VAELoader:
vae_dir = os.path.join(models_dir, "vae")
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (filter_files_extensions(os.listdir(s.vae_dir), supported_pt_extensions), )}}
return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -174,7 +182,7 @@ class CLIPLoader:
clip_dir = os.path.join(models_dir, "clip")
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (filter_files_extensions(os.listdir(s.clip_dir), supported_pt_extensions), ),
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
"stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
}}
RETURN_TYPES = ("CLIP",)