diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py index e318351c..b7ac1625 100644 --- a/model_filemanager/__init__.py +++ b/model_filemanager/__init__.py @@ -1,2 +1,2 @@ # model_manager/__init__.py -from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename +from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 712d5932..5ffec395 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -3,7 +3,7 @@ import aiohttp import os import traceback import logging -from folder_paths import models_dir +from folder_paths import folder_names_and_paths, get_folder_paths import re from typing import Callable, Any, Optional, Awaitable, Dict from enum import Enum @@ -17,6 +17,7 @@ class DownloadStatusType(Enum): COMPLETED = "completed" ERROR = "error" + @dataclass class DownloadModelStatus(): status: str @@ -29,7 +30,7 @@ class DownloadModelStatus(): self.progress_percentage = progress_percentage self.message = message self.already_existed = already_existed - + def to_dict(self) -> Dict[str, Any]: return { "status": self.status, @@ -38,102 +39,112 @@ class DownloadModelStatus(): "already_existed": self.already_existed } + async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], - model_name: str, - model_url: str, - model_sub_directory: str, + model_name: str, + model_url: str, + model_directory: str, + folder_path: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_interval: float = 1.0) -> DownloadModelStatus: """ Download a model file from a given URL into the models directory. Args: - model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): + model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): A function that makes an HTTP request. This makes it easier to mock in unit tests. - model_name (str): + model_name (str): The name of the model file to be downloaded. This will be the filename on disk. - model_url (str): + model_url (str): The URL from which to download the model. - model_sub_directory (str): - The subdirectory within the main models directory where the model + model_directory (str): + The subdirectory within the main models directory where the model should be saved (e.g., 'checkpoints', 'loras', etc.). - progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): + progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): An asynchronous function to call with progress updates. + folder_path (str); + Path to which model folder should be used as the root. Returns: DownloadModelStatus: The result of the download operation. """ - if not validate_model_subdirectory(model_sub_directory): - return DownloadModelStatus( - DownloadStatusType.ERROR, - 0, - "Invalid model subdirectory", - False - ) - if not validate_filename(model_name): return DownloadModelStatus( - DownloadStatusType.ERROR, + DownloadStatusType.ERROR, 0, - "Invalid model name", + "Invalid model name", False ) - file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) - existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) + if not model_directory in folder_names_and_paths: + return DownloadModelStatus( + DownloadStatusType.ERROR, + 0, + "Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.", + False + ) + + if not folder_path in get_folder_paths(model_directory): + return DownloadModelStatus( + DownloadStatusType.ERROR, + 0, + f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.", + False + ) + + file_path = create_model_path(model_name, folder_path) + existing_file = await check_file_exists(file_path, model_name, progress_callback) if existing_file: return existing_file try: + logging.info(f"Downloading {model_name} from {model_url}") status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) response = await model_download_request(model_url) if response.status != 200: error_message = f"Failed to download {model_name}. Status code: {response.status}" logging.error(error_message) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) + return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval) except Exception as e: logging.error(f"Error in downloading model: {e}") - return await handle_download_error(e, model_name, progress_callback, relative_path) - + return await handle_download_error(e, model_name, progress_callback) -def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: - full_model_dir = os.path.join(models_base_dir, model_directory) - os.makedirs(full_model_dir, exist_ok=True) - file_path = os.path.join(full_model_dir, model_name) + +def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]: + os.makedirs(folder_path, exist_ok=True) + file_path = os.path.join(folder_path, model_name) # Ensure the resulting path is still within the base directory abs_file_path = os.path.abspath(file_path) - abs_base_dir = os.path.abspath(str(models_base_dir)) + abs_base_dir = os.path.abspath(folder_path) if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: - raise Exception(f"Invalid model directory: {model_directory}/{model_name}") + raise Exception(f"Invalid model directory: {folder_path}/{model_name}") + + return file_path - relative_path = '/'.join([model_directory, model_name]) - return file_path, relative_path - -async def check_file_exists(file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - relative_path: str) -> Optional[DownloadModelStatus]: +async def check_file_exists(file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]] + ) -> Optional[DownloadModelStatus]: if os.path.exists(file_path): status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) return status return None -async def track_download_progress(response: aiohttp.ClientResponse, - file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - relative_path: str, +async def track_download_progress(response: aiohttp.ClientResponse, + file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], interval: float = 1.0) -> DownloadModelStatus: try: total_size = int(response.headers.get('Content-Length', 0)) @@ -144,10 +155,11 @@ async def track_download_progress(response: aiohttp.ClientResponse, nonlocal last_update_time progress = (downloaded / total_size) * 100 if total_size > 0 else 0 status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) last_update_time = time.time() - with open(file_path, 'wb') as f: + temp_file_path = file_path + '.tmp' + with open(temp_file_path, 'wb') as f: chunk_iterator = response.content.iter_chunked(8192) while True: try: @@ -156,58 +168,39 @@ async def track_download_progress(response: aiohttp.ClientResponse, break f.write(chunk) downloaded += len(chunk) - + if time.time() - last_update_time >= interval: await update_progress() + os.rename(temp_file_path, file_path) + await update_progress() - + logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) return status except Exception as e: logging.error(f"Error in track_download_progress: {e}") logging.error(traceback.format_exc()) - return await handle_download_error(e, model_name, progress_callback, relative_path) + return await handle_download_error(e, model_name, progress_callback) -async def handle_download_error(e: Exception, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Any], - relative_path: str) -> DownloadModelStatus: + +async def handle_download_error(e: Exception, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Any] + ) -> DownloadModelStatus: error_message = f"Error downloading {model_name}: {str(e)}" status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) return status -def validate_model_subdirectory(model_subdirectory: str) -> bool: - """ - Validate that the model subdirectory is safe to install into. - Must not contain relative paths, nested paths or special characters - other than underscores and hyphens. - - Args: - model_subdirectory (str): The subdirectory for the specific model type. - - Returns: - bool: True if the subdirectory is safe, False otherwise. - """ - if len(model_subdirectory) > 50: - return False - - if '..' in model_subdirectory or '/' in model_subdirectory: - return False - - if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory): - return False - - return True def validate_filename(filename: str)-> bool: """ Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. - + Args: filename (str): The filename to validate diff --git a/server.py b/server.py index ea923e85..f1971f2d 100644 --- a/server.py +++ b/server.py @@ -689,10 +689,11 @@ class PromptServer(): data = await request.json() url = data.get('url') model_directory = data.get('model_directory') + folder_path = data.get('folder_path') model_filename = data.get('model_filename') progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. - if not url or not model_directory or not model_filename: + if not url or not model_directory or not model_filename or not folder_path: return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) session = self.client_session @@ -700,7 +701,7 @@ class PromptServer(): logging.error("Client session is not initialized") return web.Response(status=500) - task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval)) + task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval)) await task return web.json_response(task.result().to_dict()) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 66150a46..128dfeb9 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -1,10 +1,17 @@ import pytest +import tempfile import aiohttp from aiohttp import ClientResponse import itertools -import os +import os from unittest.mock import AsyncMock, patch, MagicMock -from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename +from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename +import folder_paths + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdirname: + yield tmpdirname class AsyncIteratorMock: """ @@ -42,7 +49,7 @@ class ContentMock: return AsyncIteratorMock(self.chunks) @pytest.mark.asyncio -async def test_download_model_success(): +async def test_download_model_success(temp_dir): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.status = 200 mock_response.headers = {'Content-Length': '1000'} @@ -53,15 +60,13 @@ async def test_download_model_success(): mock_make_request = AsyncMock(return_value=mock_response) mock_progress_callback = AsyncMock() - # Mock file operations - mock_open = MagicMock() - mock_file = MagicMock() - mock_open.return_value.__enter__.return_value = mock_file time_values = itertools.count(0, 0.1) - with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ + fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} + + with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ - patch('builtins.open', mock_open), \ + patch('folder_paths.folder_names_and_paths', fake_paths), \ patch('time.time', side_effect=time_values): # Simulate time passing result = await download_model( @@ -69,6 +74,7 @@ async def test_download_model_success(): 'model.sft', 'http://example.com/model.sft', 'checkpoints', + temp_dir, mock_progress_callback ) @@ -83,44 +89,48 @@ async def test_download_model_success(): # Check initial call mock_progress_callback.assert_any_call( - 'checkpoints/model.sft', + 'model.sft', DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False) ) # Check final call mock_progress_callback.assert_any_call( - 'checkpoints/model.sft', + 'model.sft', DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) ) - # Verify file writing - mock_file.write.assert_any_call(b'a' * 500) - mock_file.write.assert_any_call(b'b' * 300) - mock_file.write.assert_any_call(b'c' * 200) + mock_file_path = os.path.join(temp_dir, 'model.sft') + assert os.path.exists(mock_file_path) + with open(mock_file_path, 'rb') as mock_file: + assert mock_file.read() == b''.join(chunks) + os.remove(mock_file_path) # Verify request was made mock_make_request.assert_called_once_with('http://example.com/model.sft') @pytest.mark.asyncio -async def test_download_model_url_request_failure(): +async def test_download_model_url_request_failure(temp_dir): # Mock dependencies mock_response = AsyncMock(spec=ClientResponse) mock_response.status = 404 # Simulate a "Not Found" error mock_get = AsyncMock(return_value=mock_response) mock_progress_callback = AsyncMock() + + fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} # Mock the create_model_path function - with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): - # Mock the check_file_exists function to return None (file doesn't exist) - with patch('model_filemanager.check_file_exists', return_value=None): - # Call the function - result = await download_model( - mock_get, - 'model.safetensors', - 'http://example.com/model.safetensors', - 'mock_directory', - mock_progress_callback - ) + with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \ + patch('model_filemanager.check_file_exists', return_value=None), \ + patch('folder_paths.folder_names_and_paths', fake_paths): + # Call the function + result = await download_model( + mock_get, + 'model.safetensors', + 'http://example.com/model.safetensors', + 'checkpoints', + temp_dir, + mock_progress_callback + ) # Assert the expected behavior assert isinstance(result, DownloadModelStatus) @@ -130,7 +140,7 @@ async def test_download_model_url_request_failure(): # Check that progress_callback was called with the correct arguments mock_progress_callback.assert_any_call( - 'mock_directory/model.safetensors', + 'model.safetensors', DownloadModelStatus( status=DownloadStatusType.PENDING, progress_percentage=0, @@ -139,7 +149,7 @@ async def test_download_model_url_request_failure(): ) ) mock_progress_callback.assert_called_with( - 'mock_directory/model.safetensors', + 'model.safetensors', DownloadModelStatus( status=DownloadStatusType.ERROR, progress_percentage=0, @@ -153,98 +163,125 @@ async def test_download_model_url_request_failure(): @pytest.mark.asyncio async def test_download_model_invalid_model_subdirectory(): - mock_make_request = AsyncMock() mock_progress_callback = AsyncMock() - result = await download_model( mock_make_request, 'model.sft', 'http://example.com/model.sft', '../bad_path', + '../bad_path', mock_progress_callback ) # Assert the result assert isinstance(result, DownloadModelStatus) - assert result.message == 'Invalid model subdirectory' + assert result.message.startswith('Invalid or unrecognized model directory') assert result.status == 'error' assert result.already_existed is False +@pytest.mark.asyncio +async def test_download_model_invalid_folder_path(): + mock_make_request = AsyncMock() + mock_progress_callback = AsyncMock() + + result = await download_model( + mock_make_request, + 'model.sft', + 'http://example.com/model.sft', + 'checkpoints', + 'invalid_path', + mock_progress_callback + ) + + # Assert the result + assert isinstance(result, DownloadModelStatus) + assert result.message.startswith("Invalid folder path") + assert result.status == 'error' + assert result.already_existed is False -# For create_model_path function def test_create_model_path(tmp_path, monkeypatch): - mock_models_dir = tmp_path / "models" - monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) - - model_name = "test_model.sft" - model_directory = "test_dir" - - file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir) - - assert file_path == str(mock_models_dir / model_directory / model_name) - assert relative_path == f"{model_directory}/{model_name}" + model_name = "model.safetensors" + folder_path = os.path.join(tmp_path, "mock_dir") + + file_path = create_model_path(model_name, folder_path) + + assert file_path == os.path.join(folder_path, "model.safetensors") assert os.path.exists(os.path.dirname(file_path)) + with pytest.raises(Exception, match="Invalid model directory"): + create_model_path("../path_traversal.safetensors", folder_path) + + with pytest.raises(Exception, match="Invalid model directory"): + create_model_path("/etc/some_root_path", folder_path) + @pytest.mark.asyncio async def test_check_file_exists_when_file_exists(tmp_path): file_path = tmp_path / "existing_model.sft" file_path.touch() # Create an empty file - + mock_callback = AsyncMock() - - result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") - + + result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback) + assert result is not None assert result.status == "completed" assert result.message == "existing_model.sft already exists" assert result.already_existed is True - + mock_callback.assert_called_once_with( - "test/existing_model.sft", + "existing_model.sft", DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) ) @pytest.mark.asyncio async def test_check_file_exists_when_file_does_not_exist(tmp_path): file_path = tmp_path / "non_existing_model.sft" - + mock_callback = AsyncMock() - - result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") - + + result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback) + assert result is None mock_callback.assert_not_called() @pytest.mark.asyncio -async def test_track_download_progress_no_content_length(): +async def test_track_download_progress_no_content_length(temp_dir): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.headers = {} # No Content-Length header - mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500]) + chunks = [b'a' * 500, b'b' * 500] + mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) mock_callback = AsyncMock() - mock_open = MagicMock(return_value=MagicMock()) - with patch('builtins.open', mock_open): - result = await track_download_progress( - mock_response, '/mock/path/model.sft', 'model.sft', - mock_callback, 'models/model.sft', interval=0.1 - ) + full_path = os.path.join(temp_dir, 'model.sft') + + result = await track_download_progress( + mock_response, full_path, 'model.sft', + mock_callback, interval=0.1 + ) assert result.status == "completed" + + assert os.path.exists(full_path) + with open(full_path, 'rb') as f: + assert f.read() == b''.join(chunks) + os.remove(full_path) + # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( - 'models/model.sft', + 'model.sft', DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) ) @pytest.mark.asyncio -async def test_track_download_progress_interval(): +async def test_track_download_progress_interval(temp_dir): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.headers = {'Content-Length': '1000'} - mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10) + chunks = [b'a' * 100] * 10 + mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) mock_callback = AsyncMock() mock_open = MagicMock(return_value=MagicMock()) @@ -253,18 +290,18 @@ async def test_track_download_progress_interval(): mock_time = MagicMock() mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks - with patch('builtins.open', mock_open), \ - patch('time.time', mock_time): - await track_download_progress( - mock_response, '/mock/path/model.sft', 'model.sft', - mock_callback, 'models/model.sft', interval=1.0 - ) + full_path = os.path.join(temp_dir, 'model.sft') - # Print out the actual call count and the arguments of each call for debugging - print(f"mock_callback was called {mock_callback.call_count} times") - for i, call in enumerate(mock_callback.call_args_list): - args, kwargs = call - print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%") + with patch('time.time', mock_time): + await track_download_progress( + mock_response, full_path, 'model.sft', + mock_callback, interval=1.0 + ) + + assert os.path.exists(full_path) + with open(full_path, 'rb') as f: + assert f.read() == b''.join(chunks) + os.remove(full_path) # Assert that progress was updated at least 3 times (start, at least one interval, and end) assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}" @@ -279,27 +316,6 @@ async def test_track_download_progress_interval(): assert last_call[0][1].status == "completed" assert last_call[0][1].progress_percentage == 100 -def test_valid_subdirectory(): - assert validate_model_subdirectory("valid-model123") is True - -def test_subdirectory_too_long(): - assert validate_model_subdirectory("a" * 51) is False - -def test_subdirectory_with_double_dots(): - assert validate_model_subdirectory("model/../unsafe") is False - -def test_subdirectory_with_slash(): - assert validate_model_subdirectory("model/unsafe") is False - -def test_subdirectory_with_special_characters(): - assert validate_model_subdirectory("model@unsafe") is False - -def test_subdirectory_with_underscore_and_dash(): - assert validate_model_subdirectory("valid_model-name") is True - -def test_empty_subdirectory(): - assert validate_model_subdirectory("") is False - @pytest.mark.parametrize("filename, expected", [ ("valid_model.safetensors", True), ("valid_model.sft", True),