From 3e52e0364cf81764f58e5aa4f53f0b702f4d4a81 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Tue, 13 Aug 2024 12:48:52 -0700 Subject: [PATCH] Add model downloading endpoint. (#4248) * Add model downloading endpoint. * Move client session init to async function. * Break up large function. * Send "download_progress" as websocket event. * Fixed * Fixed. * Use async mock. * Move server set up to right before run call. * Validate that model subdirectory cannot contain relative paths. * Add download_model test checking for invalid paths. * Remove DS_Store. * Consolidate DownloadStatus and DownloadModelResult * Add progress_interval as an optional parameter. * Use tuple type from annotations. * Use pydantic. * Update comment. * Revert "Use pydantic." This reverts commit 7461e8eb0073add315c65c6f5e361f0891bffc7d. * Add new line. * Add newline EOF. * Validate model filename as well. * Add comment to not reply on internal. * Restrict downloading to safetensor files only. --- main.py | 1 + model_filemanager/__init__.py | 2 + model_filemanager/download_models.py | 240 +++++++++++++ server.py | 35 +- tests-unit/prompt_server_test/__init__.py | 0 .../download_models_test.py | 321 ++++++++++++++++++ tests-unit/requirements.txt | 2 + 7 files changed, 599 insertions(+), 2 deletions(-) create mode 100644 model_filemanager/__init__.py create mode 100644 model_filemanager/download_models.py create mode 100644 tests-unit/prompt_server_test/__init__.py create mode 100644 tests-unit/prompt_server_test/download_models_test.py diff --git a/main.py b/main.py index 479643b8..b878b3e0 100644 --- a/main.py +++ b/main.py @@ -261,6 +261,7 @@ if __name__ == "__main__": call_on_start = startup_server try: + loop.run_until_complete(server.setup()) loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) except KeyboardInterrupt: logging.info("\nStopped server") diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py new file mode 100644 index 00000000..e318351c --- /dev/null +++ b/model_filemanager/__init__.py @@ -0,0 +1,2 @@ +# model_manager/__init__.py +from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py new file mode 100644 index 00000000..712d5932 --- /dev/null +++ b/model_filemanager/download_models.py @@ -0,0 +1,240 @@ +from __future__ import annotations +import aiohttp +import os +import traceback +import logging +from folder_paths import models_dir +import re +from typing import Callable, Any, Optional, Awaitable, Dict +from enum import Enum +import time +from dataclasses import dataclass + + +class DownloadStatusType(Enum): + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + ERROR = "error" + +@dataclass +class DownloadModelStatus(): + status: str + progress_percentage: float + message: str + already_existed: bool = False + + def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool): + self.status = status.value # Store the string value of the Enum + self.progress_percentage = progress_percentage + self.message = message + self.already_existed = already_existed + + def to_dict(self) -> Dict[str, Any]: + return { + "status": self.status, + "progress_percentage": self.progress_percentage, + "message": self.message, + "already_existed": self.already_existed + } + +async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], + model_name: str, + model_url: str, + model_sub_directory: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + progress_interval: float = 1.0) -> DownloadModelStatus: + """ + Download a model file from a given URL into the models directory. + + Args: + model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): + A function that makes an HTTP request. This makes it easier to mock in unit tests. + model_name (str): + The name of the model file to be downloaded. This will be the filename on disk. + model_url (str): + The URL from which to download the model. + model_sub_directory (str): + The subdirectory within the main models directory where the model + should be saved (e.g., 'checkpoints', 'loras', etc.). + progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): + An asynchronous function to call with progress updates. + + Returns: + DownloadModelStatus: The result of the download operation. + """ + if not validate_model_subdirectory(model_sub_directory): + return DownloadModelStatus( + DownloadStatusType.ERROR, + 0, + "Invalid model subdirectory", + False + ) + + if not validate_filename(model_name): + return DownloadModelStatus( + DownloadStatusType.ERROR, + 0, + "Invalid model name", + False + ) + + file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) + existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) + if existing_file: + return existing_file + + try: + status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) + await progress_callback(relative_path, status) + + response = await model_download_request(model_url) + if response.status != 200: + error_message = f"Failed to download {model_name}. Status code: {response.status}" + logging.error(error_message) + status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) + await progress_callback(relative_path, status) + return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) + + return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) + + except Exception as e: + logging.error(f"Error in downloading model: {e}") + return await handle_download_error(e, model_name, progress_callback, relative_path) + + +def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: + full_model_dir = os.path.join(models_base_dir, model_directory) + os.makedirs(full_model_dir, exist_ok=True) + file_path = os.path.join(full_model_dir, model_name) + + # Ensure the resulting path is still within the base directory + abs_file_path = os.path.abspath(file_path) + abs_base_dir = os.path.abspath(str(models_base_dir)) + if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: + raise Exception(f"Invalid model directory: {model_directory}/{model_name}") + + + relative_path = '/'.join([model_directory, model_name]) + return file_path, relative_path + +async def check_file_exists(file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + relative_path: str) -> Optional[DownloadModelStatus]: + if os.path.exists(file_path): + status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) + await progress_callback(relative_path, status) + return status + return None + + +async def track_download_progress(response: aiohttp.ClientResponse, + file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + relative_path: str, + interval: float = 1.0) -> DownloadModelStatus: + try: + total_size = int(response.headers.get('Content-Length', 0)) + downloaded = 0 + last_update_time = time.time() + + async def update_progress(): + nonlocal last_update_time + progress = (downloaded / total_size) * 100 if total_size > 0 else 0 + status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) + await progress_callback(relative_path, status) + last_update_time = time.time() + + with open(file_path, 'wb') as f: + chunk_iterator = response.content.iter_chunked(8192) + while True: + try: + chunk = await chunk_iterator.__anext__() + except StopAsyncIteration: + break + f.write(chunk) + downloaded += len(chunk) + + if time.time() - last_update_time >= interval: + await update_progress() + + await update_progress() + + logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") + status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) + await progress_callback(relative_path, status) + + return status + except Exception as e: + logging.error(f"Error in track_download_progress: {e}") + logging.error(traceback.format_exc()) + return await handle_download_error(e, model_name, progress_callback, relative_path) + +async def handle_download_error(e: Exception, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Any], + relative_path: str) -> DownloadModelStatus: + error_message = f"Error downloading {model_name}: {str(e)}" + status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) + await progress_callback(relative_path, status) + return status + +def validate_model_subdirectory(model_subdirectory: str) -> bool: + """ + Validate that the model subdirectory is safe to install into. + Must not contain relative paths, nested paths or special characters + other than underscores and hyphens. + + Args: + model_subdirectory (str): The subdirectory for the specific model type. + + Returns: + bool: True if the subdirectory is safe, False otherwise. + """ + if len(model_subdirectory) > 50: + return False + + if '..' in model_subdirectory or '/' in model_subdirectory: + return False + + if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory): + return False + + return True + +def validate_filename(filename: str)-> bool: + """ + Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. + + Args: + filename (str): The filename to validate + + Returns: + bool: True if the filename is valid, False otherwise + """ + if not filename.lower().endswith(('.sft', '.safetensors')): + return False + + # Check if the filename is empty, None, or just whitespace + if not filename or not filename.strip(): + return False + + # Check for any directory traversal attempts or invalid characters + if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']): + return False + + # Check if the filename starts with a dot (hidden file) + if filename.startswith('.'): + return False + + # Use a whitelist of allowed characters + if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename): + return False + + # Ensure the filename isn't too long + if len(filename) > 255: + return False + + return True diff --git a/server.py b/server.py index 1c07f978..0c382016 100644 --- a/server.py +++ b/server.py @@ -12,7 +12,6 @@ import json import glob import struct import ssl -import hashlib from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo from io import BytesIO @@ -28,7 +27,8 @@ import comfy.model_management import node_helpers from app.frontend_management import FrontendManager from app.user_manager import UserManager - +from model_filemanager import download_model, DownloadModelStatus +from typing import Optional class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -76,6 +76,7 @@ class PromptServer(): self.prompt_queue = None self.loop = loop self.messages = asyncio.Queue() + self.client_session:Optional[aiohttp.ClientSession] = None self.number = 0 middlewares = [cache_control] @@ -559,6 +560,36 @@ class PromptServer(): self.prompt_queue.delete_history_item(id_to_delete) return web.Response(status=200) + + # Internal route. Should not be depended upon and is subject to change at any time. + # TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket. + @routes.post("/internal/models/download") + async def download_handler(request): + async def report_progress(filename: str, status: DownloadModelStatus): + await self.send_json("download_progress", status.to_dict()) + + data = await request.json() + url = data.get('url') + model_directory = data.get('model_directory') + model_filename = data.get('model_filename') + progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. + + if not url or not model_directory or not model_filename: + return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) + + session = self.client_session + if session is None: + logging.error("Client session is not initialized") + return web.Response(status=500) + + task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval)) + await task + + return web.json_response(task.result().to_dict()) + + async def setup(self): + timeout = aiohttp.ClientTimeout(total=None) # no timeout + self.client_session = aiohttp.ClientSession(timeout=timeout) def add_routes(self): self.user_manager.add_routes(self.routes) diff --git a/tests-unit/prompt_server_test/__init__.py b/tests-unit/prompt_server_test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py new file mode 100644 index 00000000..66150a46 --- /dev/null +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -0,0 +1,321 @@ +import pytest +import aiohttp +from aiohttp import ClientResponse +import itertools +import os +from unittest.mock import AsyncMock, patch, MagicMock +from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename + +class AsyncIteratorMock: + """ + A mock class that simulates an asynchronous iterator. + This is used to mimic the behavior of aiohttp's content iterator. + """ + def __init__(self, seq): + # Convert the input sequence into an iterator + self.iter = iter(seq) + + def __aiter__(self): + # This method is called when 'async for' is used + return self + + async def __anext__(self): + # This method is called for each iteration in an 'async for' loop + try: + return next(self.iter) + except StopIteration: + # This is the asynchronous equivalent of StopIteration + raise StopAsyncIteration + +class ContentMock: + """ + A mock class that simulates the content attribute of an aiohttp ClientResponse. + This class provides the iter_chunked method which returns an async iterator of chunks. + """ + def __init__(self, chunks): + # Store the chunks that will be returned by the iterator + self.chunks = chunks + + def iter_chunked(self, chunk_size): + # This method mimics aiohttp's content.iter_chunked() + # For simplicity in testing, we ignore chunk_size and just return our predefined chunks + return AsyncIteratorMock(self.chunks) + +@pytest.mark.asyncio +async def test_download_model_success(): + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {'Content-Length': '1000'} + # Create a mock for content that returns an async iterator directly + chunks = [b'a' * 500, b'b' * 300, b'c' * 200] + mock_response.content = ContentMock(chunks) + + mock_make_request = AsyncMock(return_value=mock_response) + mock_progress_callback = AsyncMock() + + # Mock file operations + mock_open = MagicMock() + mock_file = MagicMock() + mock_open.return_value.__enter__.return_value = mock_file + time_values = itertools.count(0, 0.1) + + with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ + patch('model_filemanager.check_file_exists', return_value=None), \ + patch('builtins.open', mock_open), \ + patch('time.time', side_effect=time_values): # Simulate time passing + + result = await download_model( + mock_make_request, + 'model.sft', + 'http://example.com/model.sft', + 'checkpoints', + mock_progress_callback + ) + + # Assert the result + assert isinstance(result, DownloadModelStatus) + assert result.message == 'Successfully downloaded model.sft' + assert result.status == 'completed' + assert result.already_existed is False + + # Check progress callback calls + assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion + + # Check initial call + mock_progress_callback.assert_any_call( + 'checkpoints/model.sft', + DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False) + ) + + # Check final call + mock_progress_callback.assert_any_call( + 'checkpoints/model.sft', + DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) + ) + + # Verify file writing + mock_file.write.assert_any_call(b'a' * 500) + mock_file.write.assert_any_call(b'b' * 300) + mock_file.write.assert_any_call(b'c' * 200) + + # Verify request was made + mock_make_request.assert_called_once_with('http://example.com/model.sft') + +@pytest.mark.asyncio +async def test_download_model_url_request_failure(): + # Mock dependencies + mock_response = AsyncMock(spec=ClientResponse) + mock_response.status = 404 # Simulate a "Not Found" error + mock_get = AsyncMock(return_value=mock_response) + mock_progress_callback = AsyncMock() + + # Mock the create_model_path function + with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): + # Mock the check_file_exists function to return None (file doesn't exist) + with patch('model_filemanager.check_file_exists', return_value=None): + # Call the function + result = await download_model( + mock_get, + 'model.safetensors', + 'http://example.com/model.safetensors', + 'mock_directory', + mock_progress_callback + ) + + # Assert the expected behavior + assert isinstance(result, DownloadModelStatus) + assert result.status == 'error' + assert result.message == 'Failed to download model.safetensors. Status code: 404' + assert result.already_existed is False + + # Check that progress_callback was called with the correct arguments + mock_progress_callback.assert_any_call( + 'mock_directory/model.safetensors', + DownloadModelStatus( + status=DownloadStatusType.PENDING, + progress_percentage=0, + message='Starting download of model.safetensors', + already_existed=False + ) + ) + mock_progress_callback.assert_called_with( + 'mock_directory/model.safetensors', + DownloadModelStatus( + status=DownloadStatusType.ERROR, + progress_percentage=0, + message='Failed to download model.safetensors. Status code: 404', + already_existed=False + ) + ) + + # Verify that the get method was called with the correct URL + mock_get.assert_called_once_with('http://example.com/model.safetensors') + +@pytest.mark.asyncio +async def test_download_model_invalid_model_subdirectory(): + + mock_make_request = AsyncMock() + mock_progress_callback = AsyncMock() + + + result = await download_model( + mock_make_request, + 'model.sft', + 'http://example.com/model.sft', + '../bad_path', + mock_progress_callback + ) + + # Assert the result + assert isinstance(result, DownloadModelStatus) + assert result.message == 'Invalid model subdirectory' + assert result.status == 'error' + assert result.already_existed is False + + +# For create_model_path function +def test_create_model_path(tmp_path, monkeypatch): + mock_models_dir = tmp_path / "models" + monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) + + model_name = "test_model.sft" + model_directory = "test_dir" + + file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir) + + assert file_path == str(mock_models_dir / model_directory / model_name) + assert relative_path == f"{model_directory}/{model_name}" + assert os.path.exists(os.path.dirname(file_path)) + + +@pytest.mark.asyncio +async def test_check_file_exists_when_file_exists(tmp_path): + file_path = tmp_path / "existing_model.sft" + file_path.touch() # Create an empty file + + mock_callback = AsyncMock() + + result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") + + assert result is not None + assert result.status == "completed" + assert result.message == "existing_model.sft already exists" + assert result.already_existed is True + + mock_callback.assert_called_once_with( + "test/existing_model.sft", + DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) + ) + +@pytest.mark.asyncio +async def test_check_file_exists_when_file_does_not_exist(tmp_path): + file_path = tmp_path / "non_existing_model.sft" + + mock_callback = AsyncMock() + + result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") + + assert result is None + mock_callback.assert_not_called() + +@pytest.mark.asyncio +async def test_track_download_progress_no_content_length(): + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.headers = {} # No Content-Length header + mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500]) + + mock_callback = AsyncMock() + mock_open = MagicMock(return_value=MagicMock()) + + with patch('builtins.open', mock_open): + result = await track_download_progress( + mock_response, '/mock/path/model.sft', 'model.sft', + mock_callback, 'models/model.sft', interval=0.1 + ) + + assert result.status == "completed" + # Check that progress was reported even without knowing the total size + mock_callback.assert_any_call( + 'models/model.sft', + DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) + ) + +@pytest.mark.asyncio +async def test_track_download_progress_interval(): + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.headers = {'Content-Length': '1000'} + mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10) + + mock_callback = AsyncMock() + mock_open = MagicMock(return_value=MagicMock()) + + # Create a mock time function that returns incremental float values + mock_time = MagicMock() + mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks + + with patch('builtins.open', mock_open), \ + patch('time.time', mock_time): + await track_download_progress( + mock_response, '/mock/path/model.sft', 'model.sft', + mock_callback, 'models/model.sft', interval=1.0 + ) + + # Print out the actual call count and the arguments of each call for debugging + print(f"mock_callback was called {mock_callback.call_count} times") + for i, call in enumerate(mock_callback.call_args_list): + args, kwargs = call + print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%") + + # Assert that progress was updated at least 3 times (start, at least one interval, and end) + assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}" + + # Verify the first and last calls + first_call = mock_callback.call_args_list[0] + assert first_call[0][1].status == "in_progress" + # Allow for some initial progress, but it should be less than 50% + assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%" + + last_call = mock_callback.call_args_list[-1] + assert last_call[0][1].status == "completed" + assert last_call[0][1].progress_percentage == 100 + +def test_valid_subdirectory(): + assert validate_model_subdirectory("valid-model123") is True + +def test_subdirectory_too_long(): + assert validate_model_subdirectory("a" * 51) is False + +def test_subdirectory_with_double_dots(): + assert validate_model_subdirectory("model/../unsafe") is False + +def test_subdirectory_with_slash(): + assert validate_model_subdirectory("model/unsafe") is False + +def test_subdirectory_with_special_characters(): + assert validate_model_subdirectory("model@unsafe") is False + +def test_subdirectory_with_underscore_and_dash(): + assert validate_model_subdirectory("valid_model-name") is True + +def test_empty_subdirectory(): + assert validate_model_subdirectory("") is False + +@pytest.mark.parametrize("filename, expected", [ + ("valid_model.safetensors", True), + ("valid_model.sft", True), + ("valid model.safetensors", True), # Test with space + ("UPPERCASE_MODEL.SAFETENSORS", True), + ("model_with.multiple.dots.pt", False), + ("", False), # Empty string + ("../../../etc/passwd", False), # Path traversal attempt + ("/etc/passwd", False), # Absolute path + ("\\windows\\system32\\config\\sam", False), # Windows path + (".hidden_file.pt", False), # Hidden file + ("invalid.ckpt", False), # Invalid character + ("invalid?.ckpt", False), # Another invalid character + ("very" * 100 + ".safetensors", False), # Too long filename + ("\nmodel_with_newline.pt", False), # Newline character + ("model_with_emoji😊.pt", False), # Emoji in filename +]) +def test_validate_filename(filename, expected): + assert validate_filename(filename) == expected diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 0587502f..d70d00f4 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -1 +1,3 @@ pytest>=7.8.0 +pytest-aiohttp +pytest-asyncio