From e760bf5c4020753a5face9eaeb3ecede9a7b203d Mon Sep 17 00:00:00 2001
From: bymyself
Date: Tue, 10 Sep 2024 23:00:07 -0700
Subject: [PATCH] Add content-type filter method to folder_paths (#4054)

* Add content-type filter method to folder_paths

* Add unit tests

* Hardcode webp content-type

* Annotate content_types as Literal["image", "video", "audio"]
---
 comfy_extras/nodes_audio.py                   |  9 +---
 folder_paths.py                               | 28 ++++++++++
 tests-unit/folder_paths_test/__init__.py      |  0
 .../filter_by_content_types_test.py           | 52 +++++++++++++++++++
 4 files changed, 81 insertions(+), 8 deletions(-)
 create mode 100644 tests-unit/folder_paths_test/__init__.py
 create mode 100644 tests-unit/folder_paths_test/filter_by_content_types_test.py

diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 762b4827..6990b3f9 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -183,17 +183,10 @@ class PreviewAudio(SaveAudio):
                 }
 
 class LoadAudio:
-    SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
-
     @classmethod
     def INPUT_TYPES(s):
         input_dir = folder_paths.get_input_directory()
-        files = [
-            f for f in os.listdir(input_dir)
-            if (os.path.isfile(os.path.join(input_dir, f))
-                and f.endswith(LoadAudio.SUPPORTED_FORMATS)
-            )
-        ]
+        files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
         return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
 
     CATEGORY = "audio"
diff --git a/folder_paths.py b/folder_paths.py
index b154448f..d7a7e0c3 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -2,7 +2,9 @@ from __future__ import annotations
 
 import os
 import time
+import mimetypes
 import logging
+from typing import Set, List, Dict, Tuple, Literal
 from collections.abc import Collection
 
 supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@@ -44,6 +46,10 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
 
 filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
 
+extension_mimetypes_cache = {
+    "webp" : "image",
+}
+
 def map_legacy(folder_name: str) -> str:
     legacy = {"unet": "diffusion_models"}
     return legacy.get(folder_name, folder_name)
@@ -89,6 +95,28 @@ def get_directory_by_type(type_name: str) -> str | None:
         return get_input_directory()
     return None
 
+def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
+    """
+    Example:
+        files = os.listdir(folder_paths.get_input_directory())
+        filter_files_content_types(files, ["image", "audio", "video"])
+    """
+    global extension_mimetypes_cache
+    result = []
+    for file in files:
+        extension = file.split('.')[-1]
+        if extension not in extension_mimetypes_cache:
+            mime_type, _ = mimetypes.guess_type(file, strict=False)
+            if not mime_type:
+                continue
+            content_type = mime_type.split('/')[0]
+            extension_mimetypes_cache[extension] = content_type
+        else:
+            content_type = extension_mimetypes_cache[extension]
+
+        if content_type in content_types:
+            result.append(file)
+    return result
 
 # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
 # otherwise use default_path as base_dir
diff --git a/tests-unit/folder_paths_test/__init__.py b/tests-unit/folder_paths_test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests-unit/folder_paths_test/filter_by_content_types_test.py b/tests-unit/folder_paths_test/filter_by_content_types_test.py
new file mode 100644
index 00000000..5941bfa9
--- /dev/null
+++ b/tests-unit/folder_paths_test/filter_by_content_types_test.py
@@ -0,0 +1,52 @@
+import pytest
+import os
+import tempfile
+from folder_paths import filter_files_content_types
+
+@pytest.fixture(scope="module")
+def file_extensions():
+    return {
+        'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
+        'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
+        'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
+    }
+
+
+@pytest.fixture(scope="module")
+def mock_dir(file_extensions):
+    with tempfile.TemporaryDirectory() as directory:
+        for content_type, extensions in file_extensions.items():
+            for extension in extensions:
+                with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
+                    f.write(f"Sample {content_type} file in {extension} format")
+        yield directory
+
+
+def test_categorizes_all_correctly(mock_dir, file_extensions):
+    files = os.listdir(mock_dir)
+    for content_type, extensions in file_extensions.items():
+        filtered_files = filter_files_content_types(files, [content_type])
+        for extension in extensions:
+            assert f"sample_{content_type}.{extension}" in filtered_files
+
+
+def test_categorizes_all_uniquely(mock_dir, file_extensions):
+    files = os.listdir(mock_dir)
+    for content_type, extensions in file_extensions.items():
+        filtered_files = filter_files_content_types(files, [content_type])
+        assert len(filtered_files) == len(extensions)
+
+
+def test_handles_bad_extensions():
+    files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
+    assert filter_files_content_types(files, ["image", "audio", "video"]) == []
+
+
+def test_handles_no_extension():
+    files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
+    assert filter_files_content_types(files, ["image", "audio", "video"]) == []
+
+
+def test_handles_no_files():
+    files = []
+    assert filter_files_content_types(files, ["image", "audio", "video"]) == []
\ No newline at end of file