Internal download API: Add proper validated directory input (#4981)

* add internal /folder_paths route

returns a json maps of folder paths

* (minor) format download_models.py

* initial folder path input on download api

* actually, require folder_path and clean up some code

* partial tests update

* fix & logging

* also download to a tmp file not the live file

to avoid compounding errors from network failure

* update tests again

* test tweaks

* workaround the first tests blocker

* fix file handling in tests

* rewrite test for create_model_path

* minor doc fix

* avoid 'mock_directory'

use temp dir to avoid accidental fs pollution from tests
This commit is contained in:
Alex "mcmonkey" Goodwin 2024-09-24 16:50:45 +09:00 committed by GitHub
parent 479a427a48
commit 08c8968482
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 190 additions and 180 deletions

View File

@ -1,2 +1,2 @@
# model_manager/__init__.py # model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

View File

@ -3,7 +3,7 @@ import aiohttp
import os import os
import traceback import traceback
import logging import logging
from folder_paths import models_dir from folder_paths import folder_names_and_paths, get_folder_paths
import re import re
from typing import Callable, Any, Optional, Awaitable, Dict from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum from enum import Enum
@ -17,6 +17,7 @@ class DownloadStatusType(Enum):
COMPLETED = "completed" COMPLETED = "completed"
ERROR = "error" ERROR = "error"
@dataclass @dataclass
class DownloadModelStatus(): class DownloadModelStatus():
status: str status: str
@ -29,7 +30,7 @@ class DownloadModelStatus():
self.progress_percentage = progress_percentage self.progress_percentage = progress_percentage
self.message = message self.message = message
self.already_existed = already_existed self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"status": self.status, "status": self.status,
@ -38,102 +39,112 @@ class DownloadModelStatus():
"already_existed": self.already_existed "already_existed": self.already_existed
} }
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_sub_directory: str, model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus: progress_interval: float = 1.0) -> DownloadModelStatus:
""" """
Download a model file from a given URL into the models directory. Download a model file from a given URL into the models directory.
Args: Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests. A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str): model_name (str):
The name of the model file to be downloaded. This will be the filename on disk. The name of the model file to be downloaded. This will be the filename on disk.
model_url (str): model_url (str):
The URL from which to download the model. The URL from which to download the model.
model_sub_directory (str): model_directory (str):
The subdirectory within the main models directory where the model The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.). should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates. An asynchronous function to call with progress updates.
folder_path (str);
Path to which model folder should be used as the root.
Returns: Returns:
DownloadModelStatus: The result of the download operation. DownloadModelStatus: The result of the download operation.
""" """
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)
if not validate_filename(model_name): if not validate_filename(model_name):
return DownloadModelStatus( return DownloadModelStatus(
DownloadStatusType.ERROR, DownloadStatusType.ERROR,
0, 0,
"Invalid model name", "Invalid model name",
False False
) )
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) if not model_directory in folder_names_and_paths:
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file: if existing_file:
return existing_file return existing_file
try: try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
response = await model_download_request(model_url) response = await model_download_request(model_url)
if response.status != 200: if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}" error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message) logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e: except Exception as e:
logging.error(f"Error in downloading model: {e}") logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory) def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(full_model_dir, exist_ok=True) os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name) file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory # Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path) abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir)) abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}") raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
relative_path = '/'.join([model_directory, model_name]) async def check_file_exists(file_path: str,
return file_path, relative_path model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
async def check_file_exists(file_path: str, ) -> Optional[DownloadModelStatus]:
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path): if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
return None return None
async def track_download_progress(response: aiohttp.ClientResponse, async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str, file_path: str,
model_name: str, model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
interval: float = 1.0) -> DownloadModelStatus: interval: float = 1.0) -> DownloadModelStatus:
try: try:
total_size = int(response.headers.get('Content-Length', 0)) total_size = int(response.headers.get('Content-Length', 0))
@ -144,10 +155,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
nonlocal last_update_time nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0 progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
last_update_time = time.time() last_update_time = time.time()
with open(file_path, 'wb') as f: temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192) chunk_iterator = response.content.iter_chunked(8192)
while True: while True:
try: try:
@ -156,58 +168,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
break break
f.write(chunk) f.write(chunk)
downloaded += len(chunk) downloaded += len(chunk)
if time.time() - last_update_time >= interval: if time.time() - last_update_time >= interval:
await update_progress() await update_progress()
os.rename(temp_file_path, file_path)
await update_progress() await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
except Exception as e: except Exception as e:
logging.error(f"Error in track_download_progress: {e}") logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str, async def handle_download_error(e: Exception,
progress_callback: Callable[[str, DownloadModelStatus], Any], model_name: str,
relative_path: str) -> DownloadModelStatus: progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe to install into.
Must not contain relative paths, nested paths or special characters
other than underscores and hyphens.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True
def validate_filename(filename: str)-> bool: def validate_filename(filename: str)-> bool:
""" """
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args: Args:
filename (str): The filename to validate filename (str): The filename to validate

View File

@ -689,10 +689,11 @@ class PromptServer():
data = await request.json() data = await request.json()
url = data.get('url') url = data.get('url')
model_directory = data.get('model_directory') model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename') model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename: if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session session = self.client_session
@ -700,7 +701,7 @@ class PromptServer():
logging.error("Client session is not initialized") logging.error("Client session is not initialized")
return web.Response(status=500) return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval)) task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task await task
return web.json_response(task.result().to_dict()) return web.json_response(task.result().to_dict())

View File

@ -1,10 +1,17 @@
import pytest import pytest
import tempfile
import aiohttp import aiohttp
from aiohttp import ClientResponse from aiohttp import ClientResponse
import itertools import itertools
import os import os
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
class AsyncIteratorMock: class AsyncIteratorMock:
""" """
@ -42,7 +49,7 @@ class ContentMock:
return AsyncIteratorMock(self.chunks) return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_success(): async def test_download_model_success(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200 mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'} mock_response.headers = {'Content-Length': '1000'}
@ -53,15 +60,13 @@ async def test_download_model_success():
mock_make_request = AsyncMock(return_value=mock_response) mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock() mock_progress_callback = AsyncMock()
# Mock file operations
mock_open = MagicMock()
mock_file = MagicMock()
mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1) time_values = itertools.count(0, 0.1)
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \ patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \ patch('folder_paths.folder_names_and_paths', fake_paths), \
patch('time.time', side_effect=time_values): # Simulate time passing patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model( result = await download_model(
@ -69,6 +74,7 @@ async def test_download_model_success():
'model.sft', 'model.sft',
'http://example.com/model.sft', 'http://example.com/model.sft',
'checkpoints', 'checkpoints',
temp_dir,
mock_progress_callback mock_progress_callback
) )
@ -83,44 +89,48 @@ async def test_download_model_success():
# Check initial call # Check initial call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.sft', 'model.sft',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False) DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
) )
# Check final call # Check final call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.sft', 'model.sft',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
) )
# Verify file writing mock_file_path = os.path.join(temp_dir, 'model.sft')
mock_file.write.assert_any_call(b'a' * 500) assert os.path.exists(mock_file_path)
mock_file.write.assert_any_call(b'b' * 300) with open(mock_file_path, 'rb') as mock_file:
mock_file.write.assert_any_call(b'c' * 200) assert mock_file.read() == b''.join(chunks)
os.remove(mock_file_path)
# Verify request was made # Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft') mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_url_request_failure(): async def test_download_model_url_request_failure(temp_dir):
# Mock dependencies # Mock dependencies
mock_response = AsyncMock(spec=ClientResponse) mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error mock_response.status = 404 # Simulate a "Not Found" error
mock_get = AsyncMock(return_value=mock_response) mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock() mock_progress_callback = AsyncMock()
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
# Mock the create_model_path function # Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
# Mock the check_file_exists function to return None (file doesn't exist) patch('model_filemanager.check_file_exists', return_value=None), \
with patch('model_filemanager.check_file_exists', return_value=None): patch('folder_paths.folder_names_and_paths', fake_paths):
# Call the function # Call the function
result = await download_model( result = await download_model(
mock_get, mock_get,
'model.safetensors', 'model.safetensors',
'http://example.com/model.safetensors', 'http://example.com/model.safetensors',
'mock_directory', 'checkpoints',
mock_progress_callback temp_dir,
) mock_progress_callback
)
# Assert the expected behavior # Assert the expected behavior
assert isinstance(result, DownloadModelStatus) assert isinstance(result, DownloadModelStatus)
@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
# Check that progress_callback was called with the correct arguments # Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'mock_directory/model.safetensors', 'model.safetensors',
DownloadModelStatus( DownloadModelStatus(
status=DownloadStatusType.PENDING, status=DownloadStatusType.PENDING,
progress_percentage=0, progress_percentage=0,
@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
) )
) )
mock_progress_callback.assert_called_with( mock_progress_callback.assert_called_with(
'mock_directory/model.safetensors', 'model.safetensors',
DownloadModelStatus( DownloadModelStatus(
status=DownloadStatusType.ERROR, status=DownloadStatusType.ERROR,
progress_percentage=0, progress_percentage=0,
@ -153,98 +163,125 @@ async def test_download_model_url_request_failure():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory(): async def test_download_model_invalid_model_subdirectory():
mock_make_request = AsyncMock() mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock() mock_progress_callback = AsyncMock()
result = await download_model( result = await download_model(
mock_make_request, mock_make_request,
'model.sft', 'model.sft',
'http://example.com/model.sft', 'http://example.com/model.sft',
'../bad_path', '../bad_path',
'../bad_path',
mock_progress_callback mock_progress_callback
) )
# Assert the result # Assert the result
assert isinstance(result, DownloadModelStatus) assert isinstance(result, DownloadModelStatus)
assert result.message == 'Invalid model subdirectory' assert result.message.startswith('Invalid or unrecognized model directory')
assert result.status == 'error' assert result.status == 'error'
assert result.already_existed is False assert result.already_existed is False
@pytest.mark.asyncio
async def test_download_model_invalid_folder_path():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
'invalid_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message.startswith("Invalid folder path")
assert result.status == 'error'
assert result.already_existed is False
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch): def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models" model_name = "model.safetensors"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) folder_path = os.path.join(tmp_path, "mock_dir")
model_name = "test_model.sft" file_path = create_model_path(model_name, folder_path)
model_directory = "test_dir"
assert file_path == os.path.join(folder_path, "model.safetensors")
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
assert file_path == str(mock_models_dir / model_directory / model_name)
assert relative_path == f"{model_directory}/{model_name}"
assert os.path.exists(os.path.dirname(file_path)) assert os.path.exists(os.path.dirname(file_path))
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("../path_traversal.safetensors", folder_path)
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("/etc/some_root_path", folder_path)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path): async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.sft" file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file file_path.touch() # Create an empty file
mock_callback = AsyncMock() mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
assert result is not None assert result is not None
assert result.status == "completed" assert result.status == "completed"
assert result.message == "existing_model.sft already exists" assert result.message == "existing_model.sft already exists"
assert result.already_existed is True assert result.already_existed is True
mock_callback.assert_called_once_with( mock_callback.assert_called_once_with(
"test/existing_model.sft", "existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path): async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.sft" file_path = tmp_path / "non_existing_model.sft"
mock_callback = AsyncMock() mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
assert result is None assert result is None
mock_callback.assert_not_called() mock_callback.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_download_progress_no_content_length(): async def test_track_download_progress_no_content_length(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header mock_response.headers = {} # No Content-Length header
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500]) chunks = [b'a' * 500, b'b' * 500]
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock() mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
with patch('builtins.open', mock_open): full_path = os.path.join(temp_dir, 'model.sft')
result = await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft', result = await track_download_progress(
mock_callback, 'models/model.sft', interval=0.1 mock_response, full_path, 'model.sft',
) mock_callback, interval=0.1
)
assert result.status == "completed" assert result.status == "completed"
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Check that progress was reported even without knowing the total size # Check that progress was reported even without knowing the total size
mock_callback.assert_any_call( mock_callback.assert_any_call(
'models/model.sft', 'model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_download_progress_interval(): async def test_track_download_progress_interval(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'} mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10) chunks = [b'a' * 100] * 10
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock() mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock()) mock_open = MagicMock(return_value=MagicMock())
@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
mock_time = MagicMock() mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
with patch('builtins.open', mock_open), \ full_path = os.path.join(temp_dir, 'model.sft')
patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=1.0
)
# Print out the actual call count and the arguments of each call for debugging with patch('time.time', mock_time):
print(f"mock_callback was called {mock_callback.call_count} times") await track_download_progress(
for i, call in enumerate(mock_callback.call_args_list): mock_response, full_path, 'model.sft',
args, kwargs = call mock_callback, interval=1.0
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%") )
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Assert that progress was updated at least 3 times (start, at least one interval, and end) # Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}" assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
assert last_call[0][1].status == "completed" assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100 assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory():
assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long():
assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots():
assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash():
assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters():
assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory():
assert validate_model_subdirectory("") is False
@pytest.mark.parametrize("filename, expected", [ @pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True), ("valid_model.safetensors", True),
("valid_model.sft", True), ("valid_model.sft", True),