Add model downloading endpoint. (#4248)

* Add model downloading endpoint.

* Move client session init to async function.

* Break up large function.

* Send "download_progress" as websocket event.

* Fixed

* Fixed.

* Use async mock.

* Move server set up to right before run call.

* Validate that model subdirectory cannot contain relative paths.

* Add download_model test checking for invalid paths.

* Remove DS_Store.

* Consolidate DownloadStatus and DownloadModelResult

* Add progress_interval as an optional parameter.

* Use tuple type from annotations.

* Use pydantic.

* Update comment.

* Revert "Use pydantic."

This reverts commit 7461e8eb00.

* Add new line.

* Add newline EOF.

* Validate model filename as well.

* Add comment to not reply on internal.

* Restrict downloading to safetensor files only.
This commit is contained in:
Robin Huang 2024-08-13 12:48:52 -07:00 committed by GitHub
parent 34608de2e9
commit 3e52e0364c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 599 additions and 2 deletions

View File

@ -261,6 +261,7 @@ if __name__ == "__main__":
call_on_start = startup_server call_on_start = startup_server
try: try:
loop.run_until_complete(server.setup())
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info("\nStopped server") logging.info("\nStopped server")

View File

@ -0,0 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename

View File

@ -0,0 +1,240 @@
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import models_dir
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time
from dataclasses import dataclass
class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadModelStatus():
status: str
progress_percentage: float
message: str
already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_sub_directory: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_sub_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model name",
False
)
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if existing_file:
return existing_file
try:
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)
# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir))
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
last_update_time = time.time()
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status)
last_update_time = time.time()
with open(file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
chunk = await chunk_iterator.__anext__()
except StopAsyncIteration:
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path)
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe to install into.
Must not contain relative paths, nested paths or special characters
other than underscores and hyphens.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True
def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate
Returns:
bool: True if the filename is valid, False otherwise
"""
if not filename.lower().endswith(('.sft', '.safetensors')):
return False
# Check if the filename is empty, None, or just whitespace
if not filename or not filename.strip():
return False
# Check for any directory traversal attempts or invalid characters
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
return False
# Check if the filename starts with a dot (hidden file)
if filename.startswith('.'):
return False
# Use a whitelist of allowed characters
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
return False
# Ensure the filename isn't too long
if len(filename) > 255:
return False
return True

View File

@ -12,7 +12,6 @@ import json
import glob import glob
import struct import struct
import ssl import ssl
import hashlib
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
@ -28,7 +27,8 @@ import comfy.model_management
import node_helpers import node_helpers
from app.frontend_management import FrontendManager from app.frontend_management import FrontendManager
from app.user_manager import UserManager from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
@ -76,6 +76,7 @@ class PromptServer():
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
self.client_session:Optional[aiohttp.ClientSession] = None
self.number = 0 self.number = 0
middlewares = [cache_control] middlewares = [cache_control]
@ -560,6 +561,36 @@ class PromptServer():
return web.Response(status=200) return web.Response(status=200)
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
@routes.post("/internal/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus):
await self.send_json("download_progress", status.to_dict())
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session
if session is None:
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
await task
return web.json_response(task.result().to_dict())
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
def add_routes(self): def add_routes(self):
self.user_manager.add_routes(self.routes) self.user_manager.add_routes(self.routes)

View File

@ -0,0 +1,321 @@
import pytest
import aiohttp
from aiohttp import ClientResponse
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
class AsyncIteratorMock:
"""
A mock class that simulates an asynchronous iterator.
This is used to mimic the behavior of aiohttp's content iterator.
"""
def __init__(self, seq):
# Convert the input sequence into an iterator
self.iter = iter(seq)
def __aiter__(self):
# This method is called when 'async for' is used
return self
async def __anext__(self):
# This method is called for each iteration in an 'async for' loop
try:
return next(self.iter)
except StopIteration:
# This is the asynchronous equivalent of StopIteration
raise StopAsyncIteration
class ContentMock:
"""
A mock class that simulates the content attribute of an aiohttp ClientResponse.
This class provides the iter_chunked method which returns an async iterator of chunks.
"""
def __init__(self, chunks):
# Store the chunks that will be returned by the iterator
self.chunks = chunks
def iter_chunked(self, chunk_size):
# This method mimics aiohttp's content.iter_chunked()
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio
async def test_download_model_success():
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'}
# Create a mock for content that returns an async iterator directly
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
mock_response.content = ContentMock(chunks)
mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
# Mock file operations
mock_open = MagicMock()
mock_file = MagicMock()
mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1)
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message == 'Successfully downloaded model.sft'
assert result.status == 'completed'
assert result.already_existed is False
# Check progress callback calls
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
# Check initial call
mock_progress_callback.assert_any_call(
'checkpoints/model.sft',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
)
# Check final call
mock_progress_callback.assert_any_call(
'checkpoints/model.sft',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
)
# Verify file writing
mock_file.write.assert_any_call(b'a' * 500)
mock_file.write.assert_any_call(b'b' * 300)
mock_file.write.assert_any_call(b'c' * 200)
# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio
async def test_download_model_url_request_failure():
# Mock dependencies
mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error
mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
# Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
# Mock the check_file_exists function to return None (file doesn't exist)
with patch('model_filemanager.check_file_exists', return_value=None):
# Call the function
result = await download_model(
mock_get,
'model.safetensors',
'http://example.com/model.safetensors',
'mock_directory',
mock_progress_callback
)
# Assert the expected behavior
assert isinstance(result, DownloadModelStatus)
assert result.status == 'error'
assert result.message == 'Failed to download model.safetensors. Status code: 404'
assert result.already_existed is False
# Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call(
'mock_directory/model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.PENDING,
progress_percentage=0,
message='Starting download of model.safetensors',
already_existed=False
)
)
mock_progress_callback.assert_called_with(
'mock_directory/model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.ERROR,
progress_percentage=0,
message='Failed to download model.safetensors. Status code: 404',
already_existed=False
)
)
# Verify that the get method was called with the correct URL
mock_get.assert_called_once_with('http://example.com/model.safetensors')
@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'../bad_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message == 'Invalid model subdirectory'
assert result.status == 'error'
assert result.already_existed is False
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
model_name = "test_model.sft"
model_directory = "test_dir"
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
assert file_path == str(mock_models_dir / model_directory / model_name)
assert relative_path == f"{model_directory}/{model_name}"
assert os.path.exists(os.path.dirname(file_path))
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
assert result is not None
assert result.status == "completed"
assert result.message == "existing_model.sft already exists"
assert result.already_existed is True
mock_callback.assert_called_once_with(
"test/existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.sft"
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
assert result is None
mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
with patch('builtins.open', mock_open):
result = await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=0.1
)
assert result.status == "completed"
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'models/model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
)
@pytest.mark.asyncio
async def test_track_download_progress_interval():
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
# Create a mock time function that returns incremental float values
mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
with patch('builtins.open', mock_open), \
patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=1.0
)
# Print out the actual call count and the arguments of each call for debugging
print(f"mock_callback was called {mock_callback.call_count} times")
for i, call in enumerate(mock_callback.call_args_list):
args, kwargs = call
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
# Verify the first and last calls
first_call = mock_callback.call_args_list[0]
assert first_call[0][1].status == "in_progress"
# Allow for some initial progress, but it should be less than 50%
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
last_call = mock_callback.call_args_list[-1]
assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory():
assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long():
assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots():
assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash():
assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters():
assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory():
assert validate_model_subdirectory("") is False
@pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True),
("valid_model.sft", True),
("valid model.safetensors", True), # Test with space
("UPPERCASE_MODEL.SAFETENSORS", True),
("model_with.multiple.dots.pt", False),
("", False), # Empty string
("../../../etc/passwd", False), # Path traversal attempt
("/etc/passwd", False), # Absolute path
("\\windows\\system32\\config\\sam", False), # Windows path
(".hidden_file.pt", False), # Hidden file
("invalid<char>.ckpt", False), # Invalid character
("invalid?.ckpt", False), # Another invalid character
("very" * 100 + ".safetensors", False), # Too long filename
("\nmodel_with_newline.pt", False), # Newline character
("model_with_emoji😊.pt", False), # Emoji in filename
])
def test_validate_filename(filename, expected):
assert validate_filename(filename) == expected

View File

@ -1 +1,3 @@
pytest>=7.8.0 pytest>=7.8.0
pytest-aiohttp
pytest-asyncio