Add folder_paths so models can be in multiple paths.
This commit is contained in:
parent
51bbbf8d64
commit
e1a9e26968
|
@ -5,14 +5,12 @@ import model_management
|
|||
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions
|
||||
import torch
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
|
||||
class UpscaleModelLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "models")
|
||||
upscale_model_dir = os.path.join(models_dir, "upscale_models")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model_name": (filter_files_extensions(recursive_search(s.upscale_model_dir), supported_ckpt_extensions), ),
|
||||
return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("UPSCALE_MODEL",)
|
||||
FUNCTION = "load_model"
|
||||
|
@ -20,7 +18,7 @@ class UpscaleModelLoader:
|
|||
CATEGORY = "loaders"
|
||||
|
||||
def load_model(self, model_name):
|
||||
model_path = os.path.join(self.upscale_model_dir, model_name)
|
||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||
sd = load_torch_file(model_path)
|
||||
out = model_loading.load_state_dict(sd).eval()
|
||||
return (out, )
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
import os
|
||||
|
||||
supported_ckpt_extensions = set(['.ckpt', '.pth'])
|
||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
|
||||
try:
|
||||
import safetensors.torch
|
||||
supported_ckpt_extensions.add('.safetensors')
|
||||
supported_pt_extensions.add('.safetensors')
|
||||
except:
|
||||
print("Could not import safetensors, safetensors support disabled.")
|
||||
|
||||
|
||||
folder_names_and_paths = {}
|
||||
|
||||
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions)
|
||||
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
|
||||
|
||||
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
||||
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||
# folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
|
||||
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
|
||||
|
||||
|
||||
def add_model_folder(folder_name, full_folder_path):
|
||||
global folder_names_and_paths
|
||||
|
||||
|
||||
def recursive_search(directory):
|
||||
result = []
|
||||
for root, subdir, file in os.walk(directory, followlinks=True):
|
||||
for filepath in file:
|
||||
#we os.path,join directory with a blank string to generate a path separator at the end.
|
||||
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
|
||||
return result
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||
|
||||
|
||||
|
||||
def get_full_path(folder_name, filename):
|
||||
global folder_names_and_paths
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
for x in folders[0]:
|
||||
full_path = os.path.join(x, filename)
|
||||
if os.path.isfile(full_path):
|
||||
return full_path
|
||||
|
||||
|
||||
def get_filename_list(folder_name):
|
||||
global folder_names_and_paths
|
||||
output_list = []
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
for x in folders[0]:
|
||||
output_list += filter_files_extensions(recursive_search(x), folders[1])
|
||||
return output_list
|
||||
|
||||
|
60
nodes.py
60
nodes.py
|
@ -23,6 +23,7 @@ import comfy_extras.clip_vision
|
|||
import model_management
|
||||
import importlib
|
||||
|
||||
import folder_paths
|
||||
supported_ckpt_extensions = ['.ckpt', '.pth']
|
||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
|
||||
try:
|
||||
|
@ -208,31 +209,26 @@ class VAEEncodeForInpaint:
|
|||
|
||||
class CheckpointLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
config_dir = os.path.join(models_dir, "configs")
|
||||
ckpt_dir = os.path.join(models_dir, "checkpoints")
|
||||
embedding_directory = os.path.join(models_dir, "embeddings")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ),
|
||||
"ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}}
|
||||
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
||||
config_path = os.path.join(self.config_dir, config_name)
|
||||
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||
config_path = folder_paths.get_full_path("configs", config_name)
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
|
||||
|
||||
class CheckpointLoaderSimple:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
ckpt_dir = os.path.join(models_dir, "checkpoints")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
|
||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
@ -240,7 +236,7 @@ class CheckpointLoaderSimple:
|
|||
CATEGORY = "loaders"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
|
||||
return out
|
||||
|
||||
|
@ -261,13 +257,11 @@ class CLIPSetLastLayer:
|
|||
return (clip,)
|
||||
|
||||
class LoraLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
lora_dir = os.path.join(models_dir, "loras")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"clip": ("CLIP", ),
|
||||
"lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ),
|
||||
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
|
@ -277,16 +271,14 @@ class LoraLoader:
|
|||
CATEGORY = "loaders"
|
||||
|
||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||
lora_path = os.path.join(self.lora_dir, lora_name)
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
||||
return (model_lora, clip_lora)
|
||||
|
||||
class VAELoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
vae_dir = os.path.join(models_dir, "vae")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}}
|
||||
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
|
||||
RETURN_TYPES = ("VAE",)
|
||||
FUNCTION = "load_vae"
|
||||
|
||||
|
@ -294,16 +286,14 @@ class VAELoader:
|
|||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
vae_path = os.path.join(self.vae_dir, vae_name)
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
controlnet_dir = os.path.join(models_dir, "controlnet")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
|
||||
return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
|
||||
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
FUNCTION = "load_controlnet"
|
||||
|
@ -311,17 +301,15 @@ class ControlNetLoader:
|
|||
CATEGORY = "loaders"
|
||||
|
||||
def load_controlnet(self, control_net_name):
|
||||
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
|
||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
||||
controlnet = comfy.sd.load_controlnet(controlnet_path)
|
||||
return (controlnet,)
|
||||
|
||||
class DiffControlNetLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
controlnet_dir = os.path.join(models_dir, "controlnet")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
|
||||
"control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
|
||||
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
FUNCTION = "load_controlnet"
|
||||
|
@ -329,7 +317,7 @@ class DiffControlNetLoader:
|
|||
CATEGORY = "loaders"
|
||||
|
||||
def load_controlnet(self, model, control_net_name):
|
||||
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
|
||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
||||
controlnet = comfy.sd.load_controlnet(controlnet_path, model)
|
||||
return (controlnet,)
|
||||
|
||||
|
@ -378,11 +366,9 @@ class T2IAdapterLoader:
|
|||
return (t2i_adapter,)
|
||||
|
||||
class CLIPLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
clip_dir = os.path.join(models_dir, "clip")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
@ -390,16 +376,14 @@ class CLIPLoader:
|
|||
CATEGORY = "loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
clip_path = os.path.join(self.clip_dir, clip_name)
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
|
||||
return (clip,)
|
||||
|
||||
class CLIPVisionLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
clip_dir = os.path.join(models_dir, "clip_vision")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP_VISION",)
|
||||
FUNCTION = "load_clip"
|
||||
|
@ -407,7 +391,7 @@ class CLIPVisionLoader:
|
|||
CATEGORY = "loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
clip_path = os.path.join(self.clip_dir, clip_name)
|
||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
||||
clip_vision = comfy_extras.clip_vision.load(clip_path)
|
||||
return (clip_vision,)
|
||||
|
||||
|
@ -427,11 +411,9 @@ class CLIPVisionEncode:
|
|||
return (output,)
|
||||
|
||||
class StyleModelLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
style_model_dir = os.path.join(models_dir, "style_models")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}}
|
||||
return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}
|
||||
|
||||
RETURN_TYPES = ("STYLE_MODEL",)
|
||||
FUNCTION = "load_style_model"
|
||||
|
@ -439,7 +421,7 @@ class StyleModelLoader:
|
|||
CATEGORY = "loaders"
|
||||
|
||||
def load_style_model(self, style_model_name):
|
||||
style_model_path = os.path.join(self.style_model_dir, style_model_name)
|
||||
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
|
||||
style_model = comfy.sd.load_style_model(style_model_path)
|
||||
return (style_model,)
|
||||
|
||||
|
|
Loading…
Reference in New Issue