diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index 63704f13..44bfb8c0 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -9,7 +9,6 @@ class InternalRoutes: The top level web router for internal routes: /internal/* The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only. Check README.md for more information. - ''' def __init__(self): self.routes: web.RouteTableDef = web.RouteTableDef() diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py deleted file mode 100644 index b7ac1625..00000000 --- a/model_filemanager/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# model_manager/__init__.py -from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py deleted file mode 100644 index 6722b6e1..00000000 --- a/model_filemanager/download_models.py +++ /dev/null @@ -1,234 +0,0 @@ -#NOTE: This was an experiment and WILL BE REMOVED -from __future__ import annotations -import aiohttp -import os -import traceback -import logging -from folder_paths import folder_names_and_paths, get_folder_paths -import re -from typing import Callable, Any, Optional, Awaitable, Dict -from enum import Enum -import time -from dataclasses import dataclass - - -class DownloadStatusType(Enum): - PENDING = "pending" - IN_PROGRESS = "in_progress" - COMPLETED = "completed" - ERROR = "error" - - -@dataclass -class DownloadModelStatus(): - status: str - progress_percentage: float - message: str - already_existed: bool = False - - def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool): - self.status = status.value # Store the string value of the Enum - self.progress_percentage = progress_percentage - self.message = message - self.already_existed = already_existed - - def to_dict(self) -> Dict[str, Any]: - return { - "status": self.status, - "progress_percentage": self.progress_percentage, - "message": self.message, - "already_existed": self.already_existed - } - - -async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], - model_name: str, - model_url: str, - model_directory: str, - folder_path: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - progress_interval: float = 1.0) -> DownloadModelStatus: - """ - Download a model file from a given URL into the models directory. - - Args: - model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): - A function that makes an HTTP request. This makes it easier to mock in unit tests. - model_name (str): - The name of the model file to be downloaded. This will be the filename on disk. - model_url (str): - The URL from which to download the model. - model_directory (str): - The subdirectory within the main models directory where the model - should be saved (e.g., 'checkpoints', 'loras', etc.). - progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): - An asynchronous function to call with progress updates. - folder_path (str); - Path to which model folder should be used as the root. - - Returns: - DownloadModelStatus: The result of the download operation. - """ - if not validate_filename(model_name): - return DownloadModelStatus( - DownloadStatusType.ERROR, - 0, - "Invalid model name", - False - ) - - if not model_directory in folder_names_and_paths: - return DownloadModelStatus( - DownloadStatusType.ERROR, - 0, - "Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.", - False - ) - - if not folder_path in get_folder_paths(model_directory): - return DownloadModelStatus( - DownloadStatusType.ERROR, - 0, - f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.", - False - ) - - file_path = create_model_path(model_name, folder_path) - existing_file = await check_file_exists(file_path, model_name, progress_callback) - if existing_file: - return existing_file - - try: - logging.info(f"Downloading {model_name} from {model_url}") - status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) - await progress_callback(model_name, status) - - response = await model_download_request(model_url) - if response.status != 200: - error_message = f"Failed to download {model_name}. Status code: {response.status}" - logging.error(error_message) - status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - await progress_callback(model_name, status) - return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - - return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval) - - except Exception as e: - logging.error(f"Error in downloading model: {e}") - return await handle_download_error(e, model_name, progress_callback) - - -def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]: - os.makedirs(folder_path, exist_ok=True) - file_path = os.path.join(folder_path, model_name) - - # Ensure the resulting path is still within the base directory - abs_file_path = os.path.abspath(file_path) - abs_base_dir = os.path.abspath(folder_path) - if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: - raise Exception(f"Invalid model directory: {folder_path}/{model_name}") - - return file_path - - -async def check_file_exists(file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]] - ) -> Optional[DownloadModelStatus]: - if os.path.exists(file_path): - status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) - await progress_callback(model_name, status) - return status - return None - - -async def track_download_progress(response: aiohttp.ClientResponse, - file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - interval: float = 1.0) -> DownloadModelStatus: - try: - total_size = int(response.headers.get('Content-Length', 0)) - downloaded = 0 - last_update_time = time.time() - - async def update_progress(): - nonlocal last_update_time - progress = (downloaded / total_size) * 100 if total_size > 0 else 0 - status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) - await progress_callback(model_name, status) - last_update_time = time.time() - - temp_file_path = file_path + '.tmp' - with open(temp_file_path, 'wb') as f: - chunk_iterator = response.content.iter_chunked(8192) - while True: - try: - chunk = await chunk_iterator.__anext__() - except StopAsyncIteration: - break - f.write(chunk) - downloaded += len(chunk) - - if time.time() - last_update_time >= interval: - await update_progress() - - os.rename(temp_file_path, file_path) - - await update_progress() - - logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") - status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) - await progress_callback(model_name, status) - - return status - except Exception as e: - logging.error(f"Error in track_download_progress: {e}") - logging.error(traceback.format_exc()) - return await handle_download_error(e, model_name, progress_callback) - - -async def handle_download_error(e: Exception, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Any] - ) -> DownloadModelStatus: - error_message = f"Error downloading {model_name}: {str(e)}" - status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - await progress_callback(model_name, status) - return status - - -def validate_filename(filename: str)-> bool: - """ - Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. - - Args: - filename (str): The filename to validate - - Returns: - bool: True if the filename is valid, False otherwise - """ - if not filename.lower().endswith(('.sft', '.safetensors')): - return False - - # Check if the filename is empty, None, or just whitespace - if not filename or not filename.strip(): - return False - - # Check for any directory traversal attempts or invalid characters - if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']): - return False - - # Check if the filename starts with a dot (hidden file) - if filename.startswith('.'): - return False - - # Use a whitelist of allowed characters - if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename): - return False - - # Ensure the filename isn't too long - if len(filename) > 255: - return False - - return True diff --git a/server.py b/server.py index ada6d90c..d096cc3c 100644 --- a/server.py +++ b/server.py @@ -29,7 +29,6 @@ import comfy.model_management import node_helpers from app.frontend_management import FrontendManager from app.user_manager import UserManager -from model_filemanager import download_model, DownloadModelStatus from typing import Optional from api_server.routes.internal.internal_routes import InternalRoutes @@ -676,36 +675,6 @@ class PromptServer(): self.prompt_queue.delete_history_item(id_to_delete) return web.Response(status=200) - - # Internal route. Should not be depended upon and is subject to change at any time. - # TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket. - # NOTE: This was an experiment and WILL BE REMOVED - @routes.post("/internal/models/download") - async def download_handler(request): - async def report_progress(filename: str, status: DownloadModelStatus): - payload = status.to_dict() - payload['download_path'] = filename - await self.send_json("download_progress", payload) - - data = await request.json() - url = data.get('url') - model_directory = data.get('model_directory') - folder_path = data.get('folder_path') - model_filename = data.get('model_filename') - progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. - - if not url or not model_directory or not model_filename or not folder_path: - return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) - - session = self.client_session - if session is None: - logging.error("Client session is not initialized") - return web.Response(status=500) - - task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval)) - await task - - return web.json_response(task.result().to_dict()) async def setup(self): timeout = aiohttp.ClientTimeout(total=None) # no timeout diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py deleted file mode 100644 index 128dfeb9..00000000 --- a/tests-unit/prompt_server_test/download_models_test.py +++ /dev/null @@ -1,337 +0,0 @@ -import pytest -import tempfile -import aiohttp -from aiohttp import ClientResponse -import itertools -import os -from unittest.mock import AsyncMock, patch, MagicMock -from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename -import folder_paths - -@pytest.fixture -def temp_dir(): - with tempfile.TemporaryDirectory() as tmpdirname: - yield tmpdirname - -class AsyncIteratorMock: - """ - A mock class that simulates an asynchronous iterator. - This is used to mimic the behavior of aiohttp's content iterator. - """ - def __init__(self, seq): - # Convert the input sequence into an iterator - self.iter = iter(seq) - - def __aiter__(self): - # This method is called when 'async for' is used - return self - - async def __anext__(self): - # This method is called for each iteration in an 'async for' loop - try: - return next(self.iter) - except StopIteration: - # This is the asynchronous equivalent of StopIteration - raise StopAsyncIteration - -class ContentMock: - """ - A mock class that simulates the content attribute of an aiohttp ClientResponse. - This class provides the iter_chunked method which returns an async iterator of chunks. - """ - def __init__(self, chunks): - # Store the chunks that will be returned by the iterator - self.chunks = chunks - - def iter_chunked(self, chunk_size): - # This method mimics aiohttp's content.iter_chunked() - # For simplicity in testing, we ignore chunk_size and just return our predefined chunks - return AsyncIteratorMock(self.chunks) - -@pytest.mark.asyncio -async def test_download_model_success(temp_dir): - mock_response = AsyncMock(spec=aiohttp.ClientResponse) - mock_response.status = 200 - mock_response.headers = {'Content-Length': '1000'} - # Create a mock for content that returns an async iterator directly - chunks = [b'a' * 500, b'b' * 300, b'c' * 200] - mock_response.content = ContentMock(chunks) - - mock_make_request = AsyncMock(return_value=mock_response) - mock_progress_callback = AsyncMock() - - time_values = itertools.count(0, 0.1) - - fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} - - with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ - patch('model_filemanager.check_file_exists', return_value=None), \ - patch('folder_paths.folder_names_and_paths', fake_paths), \ - patch('time.time', side_effect=time_values): # Simulate time passing - - result = await download_model( - mock_make_request, - 'model.sft', - 'http://example.com/model.sft', - 'checkpoints', - temp_dir, - mock_progress_callback - ) - - # Assert the result - assert isinstance(result, DownloadModelStatus) - assert result.message == 'Successfully downloaded model.sft' - assert result.status == 'completed' - assert result.already_existed is False - - # Check progress callback calls - assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion - - # Check initial call - mock_progress_callback.assert_any_call( - 'model.sft', - DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False) - ) - - # Check final call - mock_progress_callback.assert_any_call( - 'model.sft', - DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) - ) - - mock_file_path = os.path.join(temp_dir, 'model.sft') - assert os.path.exists(mock_file_path) - with open(mock_file_path, 'rb') as mock_file: - assert mock_file.read() == b''.join(chunks) - os.remove(mock_file_path) - - # Verify request was made - mock_make_request.assert_called_once_with('http://example.com/model.sft') - -@pytest.mark.asyncio -async def test_download_model_url_request_failure(temp_dir): - # Mock dependencies - mock_response = AsyncMock(spec=ClientResponse) - mock_response.status = 404 # Simulate a "Not Found" error - mock_get = AsyncMock(return_value=mock_response) - mock_progress_callback = AsyncMock() - - fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} - - # Mock the create_model_path function - with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \ - patch('model_filemanager.check_file_exists', return_value=None), \ - patch('folder_paths.folder_names_and_paths', fake_paths): - # Call the function - result = await download_model( - mock_get, - 'model.safetensors', - 'http://example.com/model.safetensors', - 'checkpoints', - temp_dir, - mock_progress_callback - ) - - # Assert the expected behavior - assert isinstance(result, DownloadModelStatus) - assert result.status == 'error' - assert result.message == 'Failed to download model.safetensors. Status code: 404' - assert result.already_existed is False - - # Check that progress_callback was called with the correct arguments - mock_progress_callback.assert_any_call( - 'model.safetensors', - DownloadModelStatus( - status=DownloadStatusType.PENDING, - progress_percentage=0, - message='Starting download of model.safetensors', - already_existed=False - ) - ) - mock_progress_callback.assert_called_with( - 'model.safetensors', - DownloadModelStatus( - status=DownloadStatusType.ERROR, - progress_percentage=0, - message='Failed to download model.safetensors. Status code: 404', - already_existed=False - ) - ) - - # Verify that the get method was called with the correct URL - mock_get.assert_called_once_with('http://example.com/model.safetensors') - -@pytest.mark.asyncio -async def test_download_model_invalid_model_subdirectory(): - mock_make_request = AsyncMock() - mock_progress_callback = AsyncMock() - - result = await download_model( - mock_make_request, - 'model.sft', - 'http://example.com/model.sft', - '../bad_path', - '../bad_path', - mock_progress_callback - ) - - # Assert the result - assert isinstance(result, DownloadModelStatus) - assert result.message.startswith('Invalid or unrecognized model directory') - assert result.status == 'error' - assert result.already_existed is False - -@pytest.mark.asyncio -async def test_download_model_invalid_folder_path(): - mock_make_request = AsyncMock() - mock_progress_callback = AsyncMock() - - result = await download_model( - mock_make_request, - 'model.sft', - 'http://example.com/model.sft', - 'checkpoints', - 'invalid_path', - mock_progress_callback - ) - - # Assert the result - assert isinstance(result, DownloadModelStatus) - assert result.message.startswith("Invalid folder path") - assert result.status == 'error' - assert result.already_existed is False - -def test_create_model_path(tmp_path, monkeypatch): - model_name = "model.safetensors" - folder_path = os.path.join(tmp_path, "mock_dir") - - file_path = create_model_path(model_name, folder_path) - - assert file_path == os.path.join(folder_path, "model.safetensors") - assert os.path.exists(os.path.dirname(file_path)) - - with pytest.raises(Exception, match="Invalid model directory"): - create_model_path("../path_traversal.safetensors", folder_path) - - with pytest.raises(Exception, match="Invalid model directory"): - create_model_path("/etc/some_root_path", folder_path) - - -@pytest.mark.asyncio -async def test_check_file_exists_when_file_exists(tmp_path): - file_path = tmp_path / "existing_model.sft" - file_path.touch() # Create an empty file - - mock_callback = AsyncMock() - - result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback) - - assert result is not None - assert result.status == "completed" - assert result.message == "existing_model.sft already exists" - assert result.already_existed is True - - mock_callback.assert_called_once_with( - "existing_model.sft", - DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) - ) - -@pytest.mark.asyncio -async def test_check_file_exists_when_file_does_not_exist(tmp_path): - file_path = tmp_path / "non_existing_model.sft" - - mock_callback = AsyncMock() - - result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback) - - assert result is None - mock_callback.assert_not_called() - -@pytest.mark.asyncio -async def test_track_download_progress_no_content_length(temp_dir): - mock_response = AsyncMock(spec=aiohttp.ClientResponse) - mock_response.headers = {} # No Content-Length header - chunks = [b'a' * 500, b'b' * 500] - mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) - - mock_callback = AsyncMock() - - full_path = os.path.join(temp_dir, 'model.sft') - - result = await track_download_progress( - mock_response, full_path, 'model.sft', - mock_callback, interval=0.1 - ) - - assert result.status == "completed" - - assert os.path.exists(full_path) - with open(full_path, 'rb') as f: - assert f.read() == b''.join(chunks) - os.remove(full_path) - - # Check that progress was reported even without knowing the total size - mock_callback.assert_any_call( - 'model.sft', - DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) - ) - -@pytest.mark.asyncio -async def test_track_download_progress_interval(temp_dir): - mock_response = AsyncMock(spec=aiohttp.ClientResponse) - mock_response.headers = {'Content-Length': '1000'} - chunks = [b'a' * 100] * 10 - mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) - - mock_callback = AsyncMock() - mock_open = MagicMock(return_value=MagicMock()) - - # Create a mock time function that returns incremental float values - mock_time = MagicMock() - mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks - - full_path = os.path.join(temp_dir, 'model.sft') - - with patch('time.time', mock_time): - await track_download_progress( - mock_response, full_path, 'model.sft', - mock_callback, interval=1.0 - ) - - assert os.path.exists(full_path) - with open(full_path, 'rb') as f: - assert f.read() == b''.join(chunks) - os.remove(full_path) - - # Assert that progress was updated at least 3 times (start, at least one interval, and end) - assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}" - - # Verify the first and last calls - first_call = mock_callback.call_args_list[0] - assert first_call[0][1].status == "in_progress" - # Allow for some initial progress, but it should be less than 50% - assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%" - - last_call = mock_callback.call_args_list[-1] - assert last_call[0][1].status == "completed" - assert last_call[0][1].progress_percentage == 100 - -@pytest.mark.parametrize("filename, expected", [ - ("valid_model.safetensors", True), - ("valid_model.sft", True), - ("valid model.safetensors", True), # Test with space - ("UPPERCASE_MODEL.SAFETENSORS", True), - ("model_with.multiple.dots.pt", False), - ("", False), # Empty string - ("../../../etc/passwd", False), # Path traversal attempt - ("/etc/passwd", False), # Absolute path - ("\\windows\\system32\\config\\sam", False), # Windows path - (".hidden_file.pt", False), # Hidden file - ("invalid.ckpt", False), # Invalid character - ("invalid?.ckpt", False), # Another invalid character - ("very" * 100 + ".safetensors", False), # Too long filename - ("\nmodel_with_newline.pt", False), # Newline character - ("model_with_emoji😊.pt", False), # Emoji in filename -]) -def test_validate_filename(filename, expected): - assert validate_filename(filename) == expected