Remove internal model download endpoint

This commit is contained in:
huchenlei 2024-10-30 16:13:57 -04:00
parent daa1565b93
commit 0076783b35
5 changed files with 0 additions and 605 deletions

View File

@ -9,7 +9,6 @@ class InternalRoutes:
The top level web router for internal routes: /internal/* The top level web router for internal routes: /internal/*
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only. The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
Check README.md for more information. Check README.md for more information.
''' '''
def __init__(self): def __init__(self):
self.routes: web.RouteTableDef = web.RouteTableDef() self.routes: web.RouteTableDef = web.RouteTableDef()

View File

@ -1,2 +0,0 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

View File

@ -1,234 +0,0 @@
#NOTE: This was an experiment and WILL BE REMOVED
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import folder_names_and_paths, get_folder_paths
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time
from dataclasses import dataclass
class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadModelStatus():
status: str
progress_percentage: float
message: str
already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
folder_path (str);
Path to which model folder should be used as the root.
Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model name",
False
)
if not model_directory in folder_names_and_paths:
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file:
return existing_file
try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(model_name, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback)
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(model_name, status)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
last_update_time = time.time()
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(model_name, status)
last_update_time = time.time()
temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
chunk = await chunk_iterator.__anext__()
except StopAsyncIteration:
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
os.rename(temp_file_path, file_path)
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(model_name, status)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(model_name, status)
return status
def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate
Returns:
bool: True if the filename is valid, False otherwise
"""
if not filename.lower().endswith(('.sft', '.safetensors')):
return False
# Check if the filename is empty, None, or just whitespace
if not filename or not filename.strip():
return False
# Check for any directory traversal attempts or invalid characters
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
return False
# Check if the filename starts with a dot (hidden file)
if filename.startswith('.'):
return False
# Use a whitelist of allowed characters
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
return False
# Ensure the filename isn't too long
if len(filename) > 255:
return False
return True

View File

@ -29,7 +29,6 @@ import comfy.model_management
import node_helpers import node_helpers
from app.frontend_management import FrontendManager from app.frontend_management import FrontendManager
from app.user_manager import UserManager from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes from api_server.routes.internal.internal_routes import InternalRoutes
@ -676,36 +675,6 @@ class PromptServer():
self.prompt_queue.delete_history_item(id_to_delete) self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200) return web.Response(status=200)
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
# NOTE: This was an experiment and WILL BE REMOVED
@routes.post("/internal/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus):
payload = status.to_dict()
payload['download_path'] = filename
await self.send_json("download_progress", payload)
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session
if session is None:
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task
return web.json_response(task.result().to_dict())
async def setup(self): async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout timeout = aiohttp.ClientTimeout(total=None) # no timeout

View File

@ -1,337 +0,0 @@
import pytest
import tempfile
import aiohttp
from aiohttp import ClientResponse
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
class AsyncIteratorMock:
"""
A mock class that simulates an asynchronous iterator.
This is used to mimic the behavior of aiohttp's content iterator.
"""
def __init__(self, seq):
# Convert the input sequence into an iterator
self.iter = iter(seq)
def __aiter__(self):
# This method is called when 'async for' is used
return self
async def __anext__(self):
# This method is called for each iteration in an 'async for' loop
try:
return next(self.iter)
except StopIteration:
# This is the asynchronous equivalent of StopIteration
raise StopAsyncIteration
class ContentMock:
"""
A mock class that simulates the content attribute of an aiohttp ClientResponse.
This class provides the iter_chunked method which returns an async iterator of chunks.
"""
def __init__(self, chunks):
# Store the chunks that will be returned by the iterator
self.chunks = chunks
def iter_chunked(self, chunk_size):
# This method mimics aiohttp's content.iter_chunked()
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio
async def test_download_model_success(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'}
# Create a mock for content that returns an async iterator directly
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
mock_response.content = ContentMock(chunks)
mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
time_values = itertools.count(0, 0.1)
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('folder_paths.folder_names_and_paths', fake_paths), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
temp_dir,
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message == 'Successfully downloaded model.sft'
assert result.status == 'completed'
assert result.already_existed is False
# Check progress callback calls
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
# Check initial call
mock_progress_callback.assert_any_call(
'model.sft',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
)
# Check final call
mock_progress_callback.assert_any_call(
'model.sft',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
)
mock_file_path = os.path.join(temp_dir, 'model.sft')
assert os.path.exists(mock_file_path)
with open(mock_file_path, 'rb') as mock_file:
assert mock_file.read() == b''.join(chunks)
os.remove(mock_file_path)
# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio
async def test_download_model_url_request_failure(temp_dir):
# Mock dependencies
mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error
mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
# Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('folder_paths.folder_names_and_paths', fake_paths):
# Call the function
result = await download_model(
mock_get,
'model.safetensors',
'http://example.com/model.safetensors',
'checkpoints',
temp_dir,
mock_progress_callback
)
# Assert the expected behavior
assert isinstance(result, DownloadModelStatus)
assert result.status == 'error'
assert result.message == 'Failed to download model.safetensors. Status code: 404'
assert result.already_existed is False
# Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call(
'model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.PENDING,
progress_percentage=0,
message='Starting download of model.safetensors',
already_existed=False
)
)
mock_progress_callback.assert_called_with(
'model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.ERROR,
progress_percentage=0,
message='Failed to download model.safetensors. Status code: 404',
already_existed=False
)
)
# Verify that the get method was called with the correct URL
mock_get.assert_called_once_with('http://example.com/model.safetensors')
@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'../bad_path',
'../bad_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message.startswith('Invalid or unrecognized model directory')
assert result.status == 'error'
assert result.already_existed is False
@pytest.mark.asyncio
async def test_download_model_invalid_folder_path():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
'invalid_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message.startswith("Invalid folder path")
assert result.status == 'error'
assert result.already_existed is False
def test_create_model_path(tmp_path, monkeypatch):
model_name = "model.safetensors"
folder_path = os.path.join(tmp_path, "mock_dir")
file_path = create_model_path(model_name, folder_path)
assert file_path == os.path.join(folder_path, "model.safetensors")
assert os.path.exists(os.path.dirname(file_path))
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("../path_traversal.safetensors", folder_path)
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("/etc/some_root_path", folder_path)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
assert result is not None
assert result.status == "completed"
assert result.message == "existing_model.sft already exists"
assert result.already_existed is True
mock_callback.assert_called_once_with(
"existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.sft"
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
assert result is None
mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
chunks = [b'a' * 500, b'b' * 500]
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock()
full_path = os.path.join(temp_dir, 'model.sft')
result = await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=0.1
)
assert result.status == "completed"
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
)
@pytest.mark.asyncio
async def test_track_download_progress_interval(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'}
chunks = [b'a' * 100] * 10
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
# Create a mock time function that returns incremental float values
mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
full_path = os.path.join(temp_dir, 'model.sft')
with patch('time.time', mock_time):
await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=1.0
)
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
# Verify the first and last calls
first_call = mock_callback.call_args_list[0]
assert first_call[0][1].status == "in_progress"
# Allow for some initial progress, but it should be less than 50%
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
last_call = mock_callback.call_args_list[-1]
assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100
@pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True),
("valid_model.sft", True),
("valid model.safetensors", True), # Test with space
("UPPERCASE_MODEL.SAFETENSORS", True),
("model_with.multiple.dots.pt", False),
("", False), # Empty string
("../../../etc/passwd", False), # Path traversal attempt
("/etc/passwd", False), # Absolute path
("\\windows\\system32\\config\\sam", False), # Windows path
(".hidden_file.pt", False), # Hidden file
("invalid<char>.ckpt", False), # Invalid character
("invalid?.ckpt", False), # Another invalid character
("very" * 100 + ".safetensors", False), # Too long filename
("\nmodel_with_newline.pt", False), # Newline character
("model_with_emoji😊.pt", False), # Emoji in filename
])
def test_validate_filename(filename, expected):
assert validate_filename(filename) == expected