unet -> diffusion_models.

This commit is contained in:
comfyanonymous 2024-08-17 21:28:36 -04:00
parent bb222ceddb
commit 4f7a3cb6fb
5 changed files with 18 additions and 4 deletions

View File

@ -670,10 +670,13 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if clip is not None: if clip is not None:
load_models.append(clip.load_model()) load_models.append(clip.load_model())
clip_sd = clip.get_sd() clip_sd = clip.get_sd()
vae_sd = None
if vae is not None:
vae_sd = vae.get_sd()
model_management.load_models_gpu(load_models, force_patch_weights=True) model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys: for k in extra_keys:
sd[k] = extra_keys[k] sd[k] = extra_keys[k]

View File

@ -17,7 +17,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions) folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions) folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
@ -44,6 +44,10 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
if not os.path.exists(input_directory): if not os.path.exists(input_directory):
try: try:
os.makedirs(input_directory) os.makedirs(input_directory)
@ -128,12 +132,14 @@ def exists_annotated_filepath(name) -> bool:
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None: def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
global folder_names_and_paths global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths: if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path) folder_names_and_paths[folder_name][0].append(full_folder_path)
else: else:
folder_names_and_paths[folder_name] = ([full_folder_path], set()) folder_names_and_paths[folder_name] = ([full_folder_path], set())
def get_folder_paths(folder_name: str) -> list[str]: def get_folder_paths(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
return folder_names_and_paths[folder_name][0][:] return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]: def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
@ -180,6 +186,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None: def get_full_path(folder_name: str, filename: str) -> str | None:
global folder_names_and_paths global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths: if folder_name not in folder_names_and_paths:
return None return None
folders = folder_names_and_paths[folder_name] folders = folder_names_and_paths[folder_name]
@ -194,6 +201,7 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return None return None
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths global folder_names_and_paths
output_list = set() output_list = set()
folders = folder_names_and_paths[folder_name] folders = folder_names_and_paths[folder_name]
@ -208,6 +216,7 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None: def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
global filename_list_cache global filename_list_cache
global folder_names_and_paths global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache: if folder_name not in filename_list_cache:
return None return None
out = filename_list_cache[folder_name] out = filename_list_cache[folder_name]
@ -227,6 +236,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
return out return out
def get_filename_list(folder_name: str) -> list[str]: def get_filename_list(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
out = cached_filename_list_(folder_name) out = cached_filename_list_(folder_name)
if out is None: if out is None:
out = get_filename_list_(folder_name) out = get_filename_list_(folder_name)

View File

@ -242,6 +242,7 @@ if __name__ == "__main__":
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
if args.input_directory: if args.input_directory:
input_dir = os.path.abspath(args.input_directory) input_dir = os.path.abspath(args.input_directory)

View File

@ -855,7 +855,7 @@ class ControlNetApplyAdvanced:
class UNETLoader: class UNETLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ), return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],) "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
@ -870,7 +870,7 @@ class UNETLoader:
elif weight_dtype == "fp8_e5m2": elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2 model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path("unet", unet_name) unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,) return (model,)