Remove internal model download endpoint (#5432)
This commit is contained in:
parent
b666539595
commit
20879c78f9
|
@ -10,7 +10,6 @@ class InternalRoutes:
|
||||||
The top level web router for internal routes: /internal/*
|
The top level web router for internal routes: /internal/*
|
||||||
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||||
Check README.md for more information.
|
Check README.md for more information.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, prompt_server):
|
def __init__(self, prompt_server):
|
||||||
|
|
|
@ -1,2 +0,0 @@
|
||||||
# model_manager/__init__.py
|
|
||||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
|
|
@ -1,234 +0,0 @@
|
||||||
#NOTE: This was an experiment and WILL BE REMOVED
|
|
||||||
from __future__ import annotations
|
|
||||||
import aiohttp
|
|
||||||
import os
|
|
||||||
import traceback
|
|
||||||
import logging
|
|
||||||
from folder_paths import folder_names_and_paths, get_folder_paths
|
|
||||||
import re
|
|
||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
|
||||||
from enum import Enum
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadStatusType(Enum):
|
|
||||||
PENDING = "pending"
|
|
||||||
IN_PROGRESS = "in_progress"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
ERROR = "error"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DownloadModelStatus():
|
|
||||||
status: str
|
|
||||||
progress_percentage: float
|
|
||||||
message: str
|
|
||||||
already_existed: bool = False
|
|
||||||
|
|
||||||
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
|
|
||||||
self.status = status.value # Store the string value of the Enum
|
|
||||||
self.progress_percentage = progress_percentage
|
|
||||||
self.message = message
|
|
||||||
self.already_existed = already_existed
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"status": self.status,
|
|
||||||
"progress_percentage": self.progress_percentage,
|
|
||||||
"message": self.message,
|
|
||||||
"already_existed": self.already_existed
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
|
||||||
model_name: str,
|
|
||||||
model_url: str,
|
|
||||||
model_directory: str,
|
|
||||||
folder_path: str,
|
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
|
||||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
|
||||||
"""
|
|
||||||
Download a model file from a given URL into the models directory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
|
||||||
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
|
||||||
model_name (str):
|
|
||||||
The name of the model file to be downloaded. This will be the filename on disk.
|
|
||||||
model_url (str):
|
|
||||||
The URL from which to download the model.
|
|
||||||
model_directory (str):
|
|
||||||
The subdirectory within the main models directory where the model
|
|
||||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
|
||||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
|
||||||
An asynchronous function to call with progress updates.
|
|
||||||
folder_path (str);
|
|
||||||
Path to which model folder should be used as the root.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DownloadModelStatus: The result of the download operation.
|
|
||||||
"""
|
|
||||||
if not validate_filename(model_name):
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
"Invalid model name",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
if not model_directory in folder_names_and_paths:
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
if not folder_path in get_folder_paths(model_directory):
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
file_path = create_model_path(model_name, folder_path)
|
|
||||||
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
|
||||||
if existing_file:
|
|
||||||
return existing_file
|
|
||||||
|
|
||||||
try:
|
|
||||||
logging.info(f"Downloading {model_name} from {model_url}")
|
|
||||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
|
||||||
await progress_callback(model_name, status)
|
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
|
||||||
if response.status != 200:
|
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
|
||||||
logging.error(error_message)
|
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
|
||||||
await progress_callback(model_name, status)
|
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error in downloading model: {e}")
|
|
||||||
return await handle_download_error(e, model_name, progress_callback)
|
|
||||||
|
|
||||||
|
|
||||||
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
|
|
||||||
os.makedirs(folder_path, exist_ok=True)
|
|
||||||
file_path = os.path.join(folder_path, model_name)
|
|
||||||
|
|
||||||
# Ensure the resulting path is still within the base directory
|
|
||||||
abs_file_path = os.path.abspath(file_path)
|
|
||||||
abs_base_dir = os.path.abspath(folder_path)
|
|
||||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
|
||||||
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
|
|
||||||
|
|
||||||
return file_path
|
|
||||||
|
|
||||||
|
|
||||||
async def check_file_exists(file_path: str,
|
|
||||||
model_name: str,
|
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
|
|
||||||
) -> Optional[DownloadModelStatus]:
|
|
||||||
if os.path.exists(file_path):
|
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
|
||||||
await progress_callback(model_name, status)
|
|
||||||
return status
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
|
||||||
file_path: str,
|
|
||||||
model_name: str,
|
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
|
||||||
interval: float = 1.0) -> DownloadModelStatus:
|
|
||||||
try:
|
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
|
||||||
downloaded = 0
|
|
||||||
last_update_time = time.time()
|
|
||||||
|
|
||||||
async def update_progress():
|
|
||||||
nonlocal last_update_time
|
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
|
||||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
|
||||||
await progress_callback(model_name, status)
|
|
||||||
last_update_time = time.time()
|
|
||||||
|
|
||||||
temp_file_path = file_path + '.tmp'
|
|
||||||
with open(temp_file_path, 'wb') as f:
|
|
||||||
chunk_iterator = response.content.iter_chunked(8192)
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
chunk = await chunk_iterator.__anext__()
|
|
||||||
except StopAsyncIteration:
|
|
||||||
break
|
|
||||||
f.write(chunk)
|
|
||||||
downloaded += len(chunk)
|
|
||||||
|
|
||||||
if time.time() - last_update_time >= interval:
|
|
||||||
await update_progress()
|
|
||||||
|
|
||||||
os.rename(temp_file_path, file_path)
|
|
||||||
|
|
||||||
await update_progress()
|
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
|
||||||
await progress_callback(model_name, status)
|
|
||||||
|
|
||||||
return status
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error in track_download_progress: {e}")
|
|
||||||
logging.error(traceback.format_exc())
|
|
||||||
return await handle_download_error(e, model_name, progress_callback)
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_download_error(e: Exception,
|
|
||||||
model_name: str,
|
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Any]
|
|
||||||
) -> DownloadModelStatus:
|
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
|
||||||
await progress_callback(model_name, status)
|
|
||||||
return status
|
|
||||||
|
|
||||||
|
|
||||||
def validate_filename(filename: str)-> bool:
|
|
||||||
"""
|
|
||||||
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename (str): The filename to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the filename is valid, False otherwise
|
|
||||||
"""
|
|
||||||
if not filename.lower().endswith(('.sft', '.safetensors')):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if the filename is empty, None, or just whitespace
|
|
||||||
if not filename or not filename.strip():
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for any directory traversal attempts or invalid characters
|
|
||||||
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if the filename starts with a dot (hidden file)
|
|
||||||
if filename.startswith('.'):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Use a whitelist of allowed characters
|
|
||||||
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Ensure the filename isn't too long
|
|
||||||
if len(filename) > 255:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
31
server.py
31
server.py
|
@ -29,7 +29,6 @@ import comfy.model_management
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from app.frontend_management import FrontendManager
|
from app.frontend_management import FrontendManager
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from model_filemanager import download_model, DownloadModelStatus
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
|
|
||||||
|
@ -677,36 +676,6 @@ class PromptServer():
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
# Internal route. Should not be depended upon and is subject to change at any time.
|
|
||||||
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
|
||||||
# NOTE: This was an experiment and WILL BE REMOVED
|
|
||||||
@routes.post("/internal/models/download")
|
|
||||||
async def download_handler(request):
|
|
||||||
async def report_progress(filename: str, status: DownloadModelStatus):
|
|
||||||
payload = status.to_dict()
|
|
||||||
payload['download_path'] = filename
|
|
||||||
await self.send_json("download_progress", payload)
|
|
||||||
|
|
||||||
data = await request.json()
|
|
||||||
url = data.get('url')
|
|
||||||
model_directory = data.get('model_directory')
|
|
||||||
folder_path = data.get('folder_path')
|
|
||||||
model_filename = data.get('model_filename')
|
|
||||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
|
||||||
|
|
||||||
if not url or not model_directory or not model_filename or not folder_path:
|
|
||||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
|
||||||
|
|
||||||
session = self.client_session
|
|
||||||
if session is None:
|
|
||||||
logging.error("Client session is not initialized")
|
|
||||||
return web.Response(status=500)
|
|
||||||
|
|
||||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
|
|
||||||
await task
|
|
||||||
|
|
||||||
return web.json_response(task.result().to_dict())
|
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
|
|
@ -1,337 +0,0 @@
|
||||||
import pytest
|
|
||||||
import tempfile
|
|
||||||
import aiohttp
|
|
||||||
from aiohttp import ClientResponse
|
|
||||||
import itertools
|
|
||||||
import os
|
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
|
||||||
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def temp_dir():
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
yield tmpdirname
|
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
|
||||||
"""
|
|
||||||
A mock class that simulates an asynchronous iterator.
|
|
||||||
This is used to mimic the behavior of aiohttp's content iterator.
|
|
||||||
"""
|
|
||||||
def __init__(self, seq):
|
|
||||||
# Convert the input sequence into an iterator
|
|
||||||
self.iter = iter(seq)
|
|
||||||
|
|
||||||
def __aiter__(self):
|
|
||||||
# This method is called when 'async for' is used
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __anext__(self):
|
|
||||||
# This method is called for each iteration in an 'async for' loop
|
|
||||||
try:
|
|
||||||
return next(self.iter)
|
|
||||||
except StopIteration:
|
|
||||||
# This is the asynchronous equivalent of StopIteration
|
|
||||||
raise StopAsyncIteration
|
|
||||||
|
|
||||||
class ContentMock:
|
|
||||||
"""
|
|
||||||
A mock class that simulates the content attribute of an aiohttp ClientResponse.
|
|
||||||
This class provides the iter_chunked method which returns an async iterator of chunks.
|
|
||||||
"""
|
|
||||||
def __init__(self, chunks):
|
|
||||||
# Store the chunks that will be returned by the iterator
|
|
||||||
self.chunks = chunks
|
|
||||||
|
|
||||||
def iter_chunked(self, chunk_size):
|
|
||||||
# This method mimics aiohttp's content.iter_chunked()
|
|
||||||
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
|
|
||||||
return AsyncIteratorMock(self.chunks)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_download_model_success(temp_dir):
|
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
|
||||||
mock_response.status = 200
|
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
|
||||||
# Create a mock for content that returns an async iterator directly
|
|
||||||
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
|
|
||||||
mock_response.content = ContentMock(chunks)
|
|
||||||
|
|
||||||
mock_make_request = AsyncMock(return_value=mock_response)
|
|
||||||
mock_progress_callback = AsyncMock()
|
|
||||||
|
|
||||||
time_values = itertools.count(0, 0.1)
|
|
||||||
|
|
||||||
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
|
||||||
|
|
||||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
|
||||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
|
||||||
patch('folder_paths.folder_names_and_paths', fake_paths), \
|
|
||||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
|
||||||
|
|
||||||
result = await download_model(
|
|
||||||
mock_make_request,
|
|
||||||
'model.sft',
|
|
||||||
'http://example.com/model.sft',
|
|
||||||
'checkpoints',
|
|
||||||
temp_dir,
|
|
||||||
mock_progress_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert the result
|
|
||||||
assert isinstance(result, DownloadModelStatus)
|
|
||||||
assert result.message == 'Successfully downloaded model.sft'
|
|
||||||
assert result.status == 'completed'
|
|
||||||
assert result.already_existed is False
|
|
||||||
|
|
||||||
# Check progress callback calls
|
|
||||||
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
|
|
||||||
|
|
||||||
# Check initial call
|
|
||||||
mock_progress_callback.assert_any_call(
|
|
||||||
'model.sft',
|
|
||||||
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check final call
|
|
||||||
mock_progress_callback.assert_any_call(
|
|
||||||
'model.sft',
|
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_file_path = os.path.join(temp_dir, 'model.sft')
|
|
||||||
assert os.path.exists(mock_file_path)
|
|
||||||
with open(mock_file_path, 'rb') as mock_file:
|
|
||||||
assert mock_file.read() == b''.join(chunks)
|
|
||||||
os.remove(mock_file_path)
|
|
||||||
|
|
||||||
# Verify request was made
|
|
||||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_download_model_url_request_failure(temp_dir):
|
|
||||||
# Mock dependencies
|
|
||||||
mock_response = AsyncMock(spec=ClientResponse)
|
|
||||||
mock_response.status = 404 # Simulate a "Not Found" error
|
|
||||||
mock_get = AsyncMock(return_value=mock_response)
|
|
||||||
mock_progress_callback = AsyncMock()
|
|
||||||
|
|
||||||
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
|
||||||
|
|
||||||
# Mock the create_model_path function
|
|
||||||
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
|
|
||||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
|
||||||
patch('folder_paths.folder_names_and_paths', fake_paths):
|
|
||||||
# Call the function
|
|
||||||
result = await download_model(
|
|
||||||
mock_get,
|
|
||||||
'model.safetensors',
|
|
||||||
'http://example.com/model.safetensors',
|
|
||||||
'checkpoints',
|
|
||||||
temp_dir,
|
|
||||||
mock_progress_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert the expected behavior
|
|
||||||
assert isinstance(result, DownloadModelStatus)
|
|
||||||
assert result.status == 'error'
|
|
||||||
assert result.message == 'Failed to download model.safetensors. Status code: 404'
|
|
||||||
assert result.already_existed is False
|
|
||||||
|
|
||||||
# Check that progress_callback was called with the correct arguments
|
|
||||||
mock_progress_callback.assert_any_call(
|
|
||||||
'model.safetensors',
|
|
||||||
DownloadModelStatus(
|
|
||||||
status=DownloadStatusType.PENDING,
|
|
||||||
progress_percentage=0,
|
|
||||||
message='Starting download of model.safetensors',
|
|
||||||
already_existed=False
|
|
||||||
)
|
|
||||||
)
|
|
||||||
mock_progress_callback.assert_called_with(
|
|
||||||
'model.safetensors',
|
|
||||||
DownloadModelStatus(
|
|
||||||
status=DownloadStatusType.ERROR,
|
|
||||||
progress_percentage=0,
|
|
||||||
message='Failed to download model.safetensors. Status code: 404',
|
|
||||||
already_existed=False
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify that the get method was called with the correct URL
|
|
||||||
mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_download_model_invalid_model_subdirectory():
|
|
||||||
mock_make_request = AsyncMock()
|
|
||||||
mock_progress_callback = AsyncMock()
|
|
||||||
|
|
||||||
result = await download_model(
|
|
||||||
mock_make_request,
|
|
||||||
'model.sft',
|
|
||||||
'http://example.com/model.sft',
|
|
||||||
'../bad_path',
|
|
||||||
'../bad_path',
|
|
||||||
mock_progress_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert the result
|
|
||||||
assert isinstance(result, DownloadModelStatus)
|
|
||||||
assert result.message.startswith('Invalid or unrecognized model directory')
|
|
||||||
assert result.status == 'error'
|
|
||||||
assert result.already_existed is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_download_model_invalid_folder_path():
|
|
||||||
mock_make_request = AsyncMock()
|
|
||||||
mock_progress_callback = AsyncMock()
|
|
||||||
|
|
||||||
result = await download_model(
|
|
||||||
mock_make_request,
|
|
||||||
'model.sft',
|
|
||||||
'http://example.com/model.sft',
|
|
||||||
'checkpoints',
|
|
||||||
'invalid_path',
|
|
||||||
mock_progress_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert the result
|
|
||||||
assert isinstance(result, DownloadModelStatus)
|
|
||||||
assert result.message.startswith("Invalid folder path")
|
|
||||||
assert result.status == 'error'
|
|
||||||
assert result.already_existed is False
|
|
||||||
|
|
||||||
def test_create_model_path(tmp_path, monkeypatch):
|
|
||||||
model_name = "model.safetensors"
|
|
||||||
folder_path = os.path.join(tmp_path, "mock_dir")
|
|
||||||
|
|
||||||
file_path = create_model_path(model_name, folder_path)
|
|
||||||
|
|
||||||
assert file_path == os.path.join(folder_path, "model.safetensors")
|
|
||||||
assert os.path.exists(os.path.dirname(file_path))
|
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Invalid model directory"):
|
|
||||||
create_model_path("../path_traversal.safetensors", folder_path)
|
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Invalid model directory"):
|
|
||||||
create_model_path("/etc/some_root_path", folder_path)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
|
||||||
file_path = tmp_path / "existing_model.sft"
|
|
||||||
file_path.touch() # Create an empty file
|
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result.status == "completed"
|
|
||||||
assert result.message == "existing_model.sft already exists"
|
|
||||||
assert result.already_existed is True
|
|
||||||
|
|
||||||
mock_callback.assert_called_once_with(
|
|
||||||
"existing_model.sft",
|
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
|
||||||
file_path = tmp_path / "non_existing_model.sft"
|
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
mock_callback.assert_not_called()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_track_download_progress_no_content_length(temp_dir):
|
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
|
||||||
mock_response.headers = {} # No Content-Length header
|
|
||||||
chunks = [b'a' * 500, b'b' * 500]
|
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
|
||||||
|
|
||||||
full_path = os.path.join(temp_dir, 'model.sft')
|
|
||||||
|
|
||||||
result = await track_download_progress(
|
|
||||||
mock_response, full_path, 'model.sft',
|
|
||||||
mock_callback, interval=0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.status == "completed"
|
|
||||||
|
|
||||||
assert os.path.exists(full_path)
|
|
||||||
with open(full_path, 'rb') as f:
|
|
||||||
assert f.read() == b''.join(chunks)
|
|
||||||
os.remove(full_path)
|
|
||||||
|
|
||||||
# Check that progress was reported even without knowing the total size
|
|
||||||
mock_callback.assert_any_call(
|
|
||||||
'model.sft',
|
|
||||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_track_download_progress_interval(temp_dir):
|
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
|
||||||
chunks = [b'a' * 100] * 10
|
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
|
||||||
|
|
||||||
# Create a mock time function that returns incremental float values
|
|
||||||
mock_time = MagicMock()
|
|
||||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
|
||||||
|
|
||||||
full_path = os.path.join(temp_dir, 'model.sft')
|
|
||||||
|
|
||||||
with patch('time.time', mock_time):
|
|
||||||
await track_download_progress(
|
|
||||||
mock_response, full_path, 'model.sft',
|
|
||||||
mock_callback, interval=1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
assert os.path.exists(full_path)
|
|
||||||
with open(full_path, 'rb') as f:
|
|
||||||
assert f.read() == b''.join(chunks)
|
|
||||||
os.remove(full_path)
|
|
||||||
|
|
||||||
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
|
||||||
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
|
||||||
|
|
||||||
# Verify the first and last calls
|
|
||||||
first_call = mock_callback.call_args_list[0]
|
|
||||||
assert first_call[0][1].status == "in_progress"
|
|
||||||
# Allow for some initial progress, but it should be less than 50%
|
|
||||||
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
|
|
||||||
|
|
||||||
last_call = mock_callback.call_args_list[-1]
|
|
||||||
assert last_call[0][1].status == "completed"
|
|
||||||
assert last_call[0][1].progress_percentage == 100
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("filename, expected", [
|
|
||||||
("valid_model.safetensors", True),
|
|
||||||
("valid_model.sft", True),
|
|
||||||
("valid model.safetensors", True), # Test with space
|
|
||||||
("UPPERCASE_MODEL.SAFETENSORS", True),
|
|
||||||
("model_with.multiple.dots.pt", False),
|
|
||||||
("", False), # Empty string
|
|
||||||
("../../../etc/passwd", False), # Path traversal attempt
|
|
||||||
("/etc/passwd", False), # Absolute path
|
|
||||||
("\\windows\\system32\\config\\sam", False), # Windows path
|
|
||||||
(".hidden_file.pt", False), # Hidden file
|
|
||||||
("invalid<char>.ckpt", False), # Invalid character
|
|
||||||
("invalid?.ckpt", False), # Another invalid character
|
|
||||||
("very" * 100 + ".safetensors", False), # Too long filename
|
|
||||||
("\nmodel_with_newline.pt", False), # Newline character
|
|
||||||
("model_with_emoji😊.pt", False), # Emoji in filename
|
|
||||||
])
|
|
||||||
def test_validate_filename(filename, expected):
|
|
||||||
assert validate_filename(filename) == expected
|
|
Loading…
Reference in New Issue