Add model downloading endpoint. (#4248)
* Add model downloading endpoint.
* Move client session init to async function.
* Break up large function.
* Send "download_progress" as websocket event.
* Fixed
* Fixed.
* Use async mock.
* Move server set up to right before run call.
* Validate that model subdirectory cannot contain relative paths.
* Add download_model test checking for invalid paths.
* Remove DS_Store.
* Consolidate DownloadStatus and DownloadModelResult
* Add progress_interval as an optional parameter.
* Use tuple type from annotations.
* Use pydantic.
* Update comment.
* Revert "Use pydantic."
This reverts commit 7461e8eb00
.
* Add new line.
* Add newline EOF.
* Validate model filename as well.
* Add comment to not reply on internal.
* Restrict downloading to safetensor files only.
This commit is contained in:
parent
34608de2e9
commit
3e52e0364c
1
main.py
1
main.py
|
@ -261,6 +261,7 @@ if __name__ == "__main__":
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
loop.run_until_complete(server.setup())
|
||||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info("\nStopped server")
|
logging.info("\nStopped server")
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
# model_manager/__init__.py
|
||||||
|
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
|
@ -0,0 +1,240 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import aiohttp
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
import logging
|
||||||
|
from folder_paths import models_dir
|
||||||
|
import re
|
||||||
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
|
from enum import Enum
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadStatusType(Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DownloadModelStatus():
|
||||||
|
status: str
|
||||||
|
progress_percentage: float
|
||||||
|
message: str
|
||||||
|
already_existed: bool = False
|
||||||
|
|
||||||
|
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
|
||||||
|
self.status = status.value # Store the string value of the Enum
|
||||||
|
self.progress_percentage = progress_percentage
|
||||||
|
self.message = message
|
||||||
|
self.already_existed = already_existed
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"status": self.status,
|
||||||
|
"progress_percentage": self.progress_percentage,
|
||||||
|
"message": self.message,
|
||||||
|
"already_existed": self.already_existed
|
||||||
|
}
|
||||||
|
|
||||||
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
|
model_name: str,
|
||||||
|
model_url: str,
|
||||||
|
model_sub_directory: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
|
"""
|
||||||
|
Download a model file from a given URL into the models directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||||
|
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||||
|
model_name (str):
|
||||||
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
|
model_url (str):
|
||||||
|
The URL from which to download the model.
|
||||||
|
model_sub_directory (str):
|
||||||
|
The subdirectory within the main models directory where the model
|
||||||
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
|
An asynchronous function to call with progress updates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DownloadModelStatus: The result of the download operation.
|
||||||
|
"""
|
||||||
|
if not validate_model_subdirectory(model_sub_directory):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid model subdirectory",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not validate_filename(model_name):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid model name",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
||||||
|
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
||||||
|
if existing_file:
|
||||||
|
return existing_file
|
||||||
|
|
||||||
|
try:
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
|
response = await model_download_request(model_url)
|
||||||
|
if response.status != 200:
|
||||||
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
|
logging.error(error_message)
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
|
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error in downloading model: {e}")
|
||||||
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
||||||
|
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||||
|
os.makedirs(full_model_dir, exist_ok=True)
|
||||||
|
file_path = os.path.join(full_model_dir, model_name)
|
||||||
|
|
||||||
|
# Ensure the resulting path is still within the base directory
|
||||||
|
abs_file_path = os.path.abspath(file_path)
|
||||||
|
abs_base_dir = os.path.abspath(str(models_base_dir))
|
||||||
|
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||||
|
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
||||||
|
|
||||||
|
|
||||||
|
relative_path = '/'.join([model_directory, model_name])
|
||||||
|
return file_path, relative_path
|
||||||
|
|
||||||
|
async def check_file_exists(file_path: str,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
relative_path: str) -> Optional[DownloadModelStatus]:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return status
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||||
|
file_path: str,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
relative_path: str,
|
||||||
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
|
try:
|
||||||
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
|
downloaded = 0
|
||||||
|
last_update_time = time.time()
|
||||||
|
|
||||||
|
async def update_progress():
|
||||||
|
nonlocal last_update_time
|
||||||
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
last_update_time = time.time()
|
||||||
|
|
||||||
|
with open(file_path, 'wb') as f:
|
||||||
|
chunk_iterator = response.content.iter_chunked(8192)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = await chunk_iterator.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
|
||||||
|
if time.time() - last_update_time >= interval:
|
||||||
|
await update_progress()
|
||||||
|
|
||||||
|
await update_progress()
|
||||||
|
|
||||||
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
|
return status
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error in track_download_progress: {e}")
|
||||||
|
logging.error(traceback.format_exc())
|
||||||
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
|
async def handle_download_error(e: Exception,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||||
|
relative_path: str) -> DownloadModelStatus:
|
||||||
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return status
|
||||||
|
|
||||||
|
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that the model subdirectory is safe to install into.
|
||||||
|
Must not contain relative paths, nested paths or special characters
|
||||||
|
other than underscores and hyphens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_subdirectory (str): The subdirectory for the specific model type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the subdirectory is safe, False otherwise.
|
||||||
|
"""
|
||||||
|
if len(model_subdirectory) > 50:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if '..' in model_subdirectory or '/' in model_subdirectory:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_filename(filename: str)-> bool:
|
||||||
|
"""
|
||||||
|
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The filename to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the filename is valid, False otherwise
|
||||||
|
"""
|
||||||
|
if not filename.lower().endswith(('.sft', '.safetensors')):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the filename is empty, None, or just whitespace
|
||||||
|
if not filename or not filename.strip():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for any directory traversal attempts or invalid characters
|
||||||
|
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the filename starts with a dot (hidden file)
|
||||||
|
if filename.startswith('.'):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Use a whitelist of allowed characters
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Ensure the filename isn't too long
|
||||||
|
if len(filename) > 255:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
35
server.py
35
server.py
|
@ -12,7 +12,6 @@ import json
|
||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
import ssl
|
import ssl
|
||||||
import hashlib
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
@ -28,7 +27,8 @@ import comfy.model_management
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from app.frontend_management import FrontendManager
|
from app.frontend_management import FrontendManager
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
|
from model_filemanager import download_model, DownloadModelStatus
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
|
@ -76,6 +76,7 @@ class PromptServer():
|
||||||
self.prompt_queue = None
|
self.prompt_queue = None
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.messages = asyncio.Queue()
|
self.messages = asyncio.Queue()
|
||||||
|
self.client_session:Optional[aiohttp.ClientSession] = None
|
||||||
self.number = 0
|
self.number = 0
|
||||||
|
|
||||||
middlewares = [cache_control]
|
middlewares = [cache_control]
|
||||||
|
@ -559,6 +560,36 @@ class PromptServer():
|
||||||
self.prompt_queue.delete_history_item(id_to_delete)
|
self.prompt_queue.delete_history_item(id_to_delete)
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
# Internal route. Should not be depended upon and is subject to change at any time.
|
||||||
|
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
||||||
|
@routes.post("/internal/models/download")
|
||||||
|
async def download_handler(request):
|
||||||
|
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||||
|
await self.send_json("download_progress", status.to_dict())
|
||||||
|
|
||||||
|
data = await request.json()
|
||||||
|
url = data.get('url')
|
||||||
|
model_directory = data.get('model_directory')
|
||||||
|
model_filename = data.get('model_filename')
|
||||||
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
|
if not url or not model_directory or not model_filename:
|
||||||
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
|
|
||||||
|
session = self.client_session
|
||||||
|
if session is None:
|
||||||
|
logging.error("Client session is not initialized")
|
||||||
|
return web.Response(status=500)
|
||||||
|
|
||||||
|
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
||||||
|
await task
|
||||||
|
|
||||||
|
return web.json_response(task.result().to_dict())
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
|
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.user_manager.add_routes(self.routes)
|
self.user_manager.add_routes(self.routes)
|
||||||
|
|
|
@ -0,0 +1,321 @@
|
||||||
|
import pytest
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import ClientResponse
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||||
|
|
||||||
|
class AsyncIteratorMock:
|
||||||
|
"""
|
||||||
|
A mock class that simulates an asynchronous iterator.
|
||||||
|
This is used to mimic the behavior of aiohttp's content iterator.
|
||||||
|
"""
|
||||||
|
def __init__(self, seq):
|
||||||
|
# Convert the input sequence into an iterator
|
||||||
|
self.iter = iter(seq)
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
# This method is called when 'async for' is used
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
# This method is called for each iteration in an 'async for' loop
|
||||||
|
try:
|
||||||
|
return next(self.iter)
|
||||||
|
except StopIteration:
|
||||||
|
# This is the asynchronous equivalent of StopIteration
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
class ContentMock:
|
||||||
|
"""
|
||||||
|
A mock class that simulates the content attribute of an aiohttp ClientResponse.
|
||||||
|
This class provides the iter_chunked method which returns an async iterator of chunks.
|
||||||
|
"""
|
||||||
|
def __init__(self, chunks):
|
||||||
|
# Store the chunks that will be returned by the iterator
|
||||||
|
self.chunks = chunks
|
||||||
|
|
||||||
|
def iter_chunked(self, chunk_size):
|
||||||
|
# This method mimics aiohttp's content.iter_chunked()
|
||||||
|
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
|
||||||
|
return AsyncIteratorMock(self.chunks)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_success():
|
||||||
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
|
# Create a mock for content that returns an async iterator directly
|
||||||
|
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
|
||||||
|
mock_response.content = ContentMock(chunks)
|
||||||
|
|
||||||
|
mock_make_request = AsyncMock(return_value=mock_response)
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
# Mock file operations
|
||||||
|
mock_open = MagicMock()
|
||||||
|
mock_file = MagicMock()
|
||||||
|
mock_open.return_value.__enter__.return_value = mock_file
|
||||||
|
time_values = itertools.count(0, 0.1)
|
||||||
|
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
||||||
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
|
patch('builtins.open', mock_open), \
|
||||||
|
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||||
|
|
||||||
|
result = await download_model(
|
||||||
|
mock_make_request,
|
||||||
|
'model.sft',
|
||||||
|
'http://example.com/model.sft',
|
||||||
|
'checkpoints',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.message == 'Successfully downloaded model.sft'
|
||||||
|
assert result.status == 'completed'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
# Check progress callback calls
|
||||||
|
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
|
||||||
|
|
||||||
|
# Check initial call
|
||||||
|
mock_progress_callback.assert_any_call(
|
||||||
|
'checkpoints/model.sft',
|
||||||
|
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check final call
|
||||||
|
mock_progress_callback.assert_any_call(
|
||||||
|
'checkpoints/model.sft',
|
||||||
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify file writing
|
||||||
|
mock_file.write.assert_any_call(b'a' * 500)
|
||||||
|
mock_file.write.assert_any_call(b'b' * 300)
|
||||||
|
mock_file.write.assert_any_call(b'c' * 200)
|
||||||
|
|
||||||
|
# Verify request was made
|
||||||
|
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_url_request_failure():
|
||||||
|
# Mock dependencies
|
||||||
|
mock_response = AsyncMock(spec=ClientResponse)
|
||||||
|
mock_response.status = 404 # Simulate a "Not Found" error
|
||||||
|
mock_get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
# Mock the create_model_path function
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
||||||
|
# Mock the check_file_exists function to return None (file doesn't exist)
|
||||||
|
with patch('model_filemanager.check_file_exists', return_value=None):
|
||||||
|
# Call the function
|
||||||
|
result = await download_model(
|
||||||
|
mock_get,
|
||||||
|
'model.safetensors',
|
||||||
|
'http://example.com/model.safetensors',
|
||||||
|
'mock_directory',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the expected behavior
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.status == 'error'
|
||||||
|
assert result.message == 'Failed to download model.safetensors. Status code: 404'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
# Check that progress_callback was called with the correct arguments
|
||||||
|
mock_progress_callback.assert_any_call(
|
||||||
|
'mock_directory/model.safetensors',
|
||||||
|
DownloadModelStatus(
|
||||||
|
status=DownloadStatusType.PENDING,
|
||||||
|
progress_percentage=0,
|
||||||
|
message='Starting download of model.safetensors',
|
||||||
|
already_existed=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_progress_callback.assert_called_with(
|
||||||
|
'mock_directory/model.safetensors',
|
||||||
|
DownloadModelStatus(
|
||||||
|
status=DownloadStatusType.ERROR,
|
||||||
|
progress_percentage=0,
|
||||||
|
message='Failed to download model.safetensors. Status code: 404',
|
||||||
|
already_existed=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the get method was called with the correct URL
|
||||||
|
mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_invalid_model_subdirectory():
|
||||||
|
|
||||||
|
mock_make_request = AsyncMock()
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
result = await download_model(
|
||||||
|
mock_make_request,
|
||||||
|
'model.sft',
|
||||||
|
'http://example.com/model.sft',
|
||||||
|
'../bad_path',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.message == 'Invalid model subdirectory'
|
||||||
|
assert result.status == 'error'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
|
||||||
|
# For create_model_path function
|
||||||
|
def test_create_model_path(tmp_path, monkeypatch):
|
||||||
|
mock_models_dir = tmp_path / "models"
|
||||||
|
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
||||||
|
|
||||||
|
model_name = "test_model.sft"
|
||||||
|
model_directory = "test_dir"
|
||||||
|
|
||||||
|
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
||||||
|
|
||||||
|
assert file_path == str(mock_models_dir / model_directory / model_name)
|
||||||
|
assert relative_path == f"{model_directory}/{model_name}"
|
||||||
|
assert os.path.exists(os.path.dirname(file_path))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||||
|
file_path = tmp_path / "existing_model.sft"
|
||||||
|
file_path.touch() # Create an empty file
|
||||||
|
|
||||||
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
|
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert result.message == "existing_model.sft already exists"
|
||||||
|
assert result.already_existed is True
|
||||||
|
|
||||||
|
mock_callback.assert_called_once_with(
|
||||||
|
"test/existing_model.sft",
|
||||||
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||||
|
file_path = tmp_path / "non_existing_model.sft"
|
||||||
|
|
||||||
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
|
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
mock_callback.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_download_progress_no_content_length():
|
||||||
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
|
mock_response.headers = {} # No Content-Length header
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
||||||
|
|
||||||
|
mock_callback = AsyncMock()
|
||||||
|
mock_open = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
with patch('builtins.open', mock_open):
|
||||||
|
result = await track_download_progress(
|
||||||
|
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||||
|
mock_callback, 'models/model.sft', interval=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "completed"
|
||||||
|
# Check that progress was reported even without knowing the total size
|
||||||
|
mock_callback.assert_any_call(
|
||||||
|
'models/model.sft',
|
||||||
|
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_track_download_progress_interval():
|
||||||
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
||||||
|
|
||||||
|
mock_callback = AsyncMock()
|
||||||
|
mock_open = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
# Create a mock time function that returns incremental float values
|
||||||
|
mock_time = MagicMock()
|
||||||
|
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||||
|
|
||||||
|
with patch('builtins.open', mock_open), \
|
||||||
|
patch('time.time', mock_time):
|
||||||
|
await track_download_progress(
|
||||||
|
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||||
|
mock_callback, 'models/model.sft', interval=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print out the actual call count and the arguments of each call for debugging
|
||||||
|
print(f"mock_callback was called {mock_callback.call_count} times")
|
||||||
|
for i, call in enumerate(mock_callback.call_args_list):
|
||||||
|
args, kwargs = call
|
||||||
|
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
||||||
|
|
||||||
|
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||||
|
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||||
|
|
||||||
|
# Verify the first and last calls
|
||||||
|
first_call = mock_callback.call_args_list[0]
|
||||||
|
assert first_call[0][1].status == "in_progress"
|
||||||
|
# Allow for some initial progress, but it should be less than 50%
|
||||||
|
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
|
||||||
|
|
||||||
|
last_call = mock_callback.call_args_list[-1]
|
||||||
|
assert last_call[0][1].status == "completed"
|
||||||
|
assert last_call[0][1].progress_percentage == 100
|
||||||
|
|
||||||
|
def test_valid_subdirectory():
|
||||||
|
assert validate_model_subdirectory("valid-model123") is True
|
||||||
|
|
||||||
|
def test_subdirectory_too_long():
|
||||||
|
assert validate_model_subdirectory("a" * 51) is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_double_dots():
|
||||||
|
assert validate_model_subdirectory("model/../unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_slash():
|
||||||
|
assert validate_model_subdirectory("model/unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_special_characters():
|
||||||
|
assert validate_model_subdirectory("model@unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_underscore_and_dash():
|
||||||
|
assert validate_model_subdirectory("valid_model-name") is True
|
||||||
|
|
||||||
|
def test_empty_subdirectory():
|
||||||
|
assert validate_model_subdirectory("") is False
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("filename, expected", [
|
||||||
|
("valid_model.safetensors", True),
|
||||||
|
("valid_model.sft", True),
|
||||||
|
("valid model.safetensors", True), # Test with space
|
||||||
|
("UPPERCASE_MODEL.SAFETENSORS", True),
|
||||||
|
("model_with.multiple.dots.pt", False),
|
||||||
|
("", False), # Empty string
|
||||||
|
("../../../etc/passwd", False), # Path traversal attempt
|
||||||
|
("/etc/passwd", False), # Absolute path
|
||||||
|
("\\windows\\system32\\config\\sam", False), # Windows path
|
||||||
|
(".hidden_file.pt", False), # Hidden file
|
||||||
|
("invalid<char>.ckpt", False), # Invalid character
|
||||||
|
("invalid?.ckpt", False), # Another invalid character
|
||||||
|
("very" * 100 + ".safetensors", False), # Too long filename
|
||||||
|
("\nmodel_with_newline.pt", False), # Newline character
|
||||||
|
("model_with_emoji😊.pt", False), # Emoji in filename
|
||||||
|
])
|
||||||
|
def test_validate_filename(filename, expected):
|
||||||
|
assert validate_filename(filename) == expected
|
|
@ -1 +1,3 @@
|
||||||
pytest>=7.8.0
|
pytest>=7.8.0
|
||||||
|
pytest-aiohttp
|
||||||
|
pytest-asyncio
|
||||||
|
|
Loading…
Reference in New Issue