2024-08-13 19:48:52 +00:00
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
2024-09-24 07:50:45 +00:00
from folder_paths import folder_names_and_paths , get_folder_paths
2024-08-13 19:48:52 +00:00
import re
from typing import Callable , Any , Optional , Awaitable , Dict
from enum import Enum
import time
from dataclasses import dataclass
class DownloadStatusType(Enum):
    """Lifecycle states of a model download; the string values are what status objects store."""

    PENDING = "pending"          # download accepted, transfer not started yet
    IN_PROGRESS = "in_progress"  # bytes are being streamed to disk
    COMPLETED = "completed"      # file is in place (freshly downloaded or already present)
    ERROR = "error"              # validation or transfer failed

@dataclass
class DownloadModelStatus:
    """
    Snapshot of a model download, passed to progress callbacks and returned to callers.

    ``status`` holds the string value of a ``DownloadStatusType`` (e.g. ``"error"``), not the
    enum member itself, so the dict produced by ``to_dict()`` carries a plain string.
    """
    status: str
    progress_percentage: float
    message: str
    already_existed: bool = False

    def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool = False):
        # @dataclass keeps an __init__ defined in the class body instead of generating one;
        # this one accepts the enum member and stores only its string value.
        self.status = status.value  # Store the string value of the Enum
        self.progress_percentage = progress_percentage
        self.message = message
        # Default matches the field declaration above so the flag can be omitted.
        self.already_existed = already_existed

    def to_dict(self) -> Dict[str, Any]:
        """Return the status as a plain dict keyed by the dataclass field names."""
        return {
            "status": self.status,
            "progress_percentage": self.progress_percentage,
            "message": self.message,
            "already_existed": self.already_existed
        }

async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
                         model_name: str,
                         model_url: str,
                         model_directory: str,
                         folder_path: str,
                         progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                         progress_interval: float = 1.0) -> DownloadModelStatus:
    """
    Download a model file from a given URL into the models directory.

    Args:
        model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
            A function that makes an HTTP request. This makes it easier to mock in unit tests.
        model_name (str):
            The name of the model file to be downloaded. This will be the filename on disk.
        model_url (str):
            The URL from which to download the model.
        model_directory (str):
            The subdirectory within the main models directory where the model
            should be saved (e.g., 'checkpoints', 'loras', etc.). Must be a key of
            ``folder_names_and_paths``.
        folder_path (str):
            Path to which model folder should be used as the root. Must be one of
            ``get_folder_paths(model_directory)``.
        progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
            An asynchronous function to call with progress updates.
        progress_interval (float):
            Minimum number of seconds between periodic IN_PROGRESS updates while the
            file is streamed. Defaults to 1.0.

    Returns:
        DownloadModelStatus: The result of the download operation. Argument validation
        failures are returned as ERROR statuses without calling ``progress_callback``;
        any failure after validation is also reported through ``progress_callback``.
    """
    if not validate_filename(model_name):
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            "Invalid model name",
            False
        )

    if model_directory not in folder_names_and_paths:
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            "Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
            False
        )

    # Look the known roots up once; the same list feeds both the check and the error text.
    known_paths = get_folder_paths(model_directory)
    if folder_path not in known_paths:
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            f"Invalid folder path '{folder_path}', does not match the list of known directories ({known_paths}). If you're seeing this in the downloader UI, you may need to refresh the page.",
            False
        )

    try:
        # create_model_path touches the filesystem (makedirs) and can raise; keeping it inside
        # the try turns such failures into an ERROR status like every other failure here.
        file_path = create_model_path(model_name, folder_path)
        existing_file = await check_file_exists(file_path, model_name, progress_callback)
        if existing_file is not None:
            return existing_file

        logging.info(f"Downloading {model_name} from {model_url}")
        status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
        await progress_callback(model_name, status)

        response = await model_download_request(model_url)
        if response.status != 200:
            error_message = f"Failed to download {model_name}. Status code: {response.status}"
            logging.error(error_message)
            status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
            await progress_callback(model_name, status)
            # Return the same status the callback saw rather than building a duplicate.
            return status

        return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)

    except Exception as e:
        logging.error(f"Error in downloading model: {e}")
        return await handle_download_error(e, model_name, progress_callback)

def create_model_path(model_name: str, folder_path: str) -> str:
    """
    Build the on-disk path for ``model_name`` inside ``folder_path``, creating the folder.

    Args:
        model_name (str): File name of the model, joined onto ``folder_path``.
        folder_path (str): Root directory the model must live in; created if missing.

    Returns:
        str: ``os.path.join(folder_path, model_name)``.

    Raises:
        Exception: If the joined path does not resolve to a location strictly inside
            ``folder_path`` (e.g. ``..`` segments, an absolute ``model_name``, a different
            drive, or a name that resolves to the folder itself). The folder is not
            created in that case.
    """
    file_path = os.path.join(folder_path, model_name)

    # Ensure the resulting path is still within the base directory. commonpath compares
    # whole path components; a character-wise commonprefix would accept a sibling such as
    # "<base>_evil/x" because it shares the string prefix "<base>".
    abs_file_path = os.path.abspath(file_path)
    abs_base_dir = os.path.abspath(folder_path)
    try:
        inside = os.path.commonpath([abs_base_dir, abs_file_path]) == abs_base_dir
    except ValueError:
        # commonpath raises when the paths cannot share a root (e.g. different drives).
        inside = False
    if not inside or abs_file_path == abs_base_dir:
        raise Exception(f"Invalid model directory: {folder_path}/{model_name}")

    os.makedirs(folder_path, exist_ok=True)
    return file_path

async def check_file_exists(file_path: str,
                            model_name: str,
                            progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
                            ) -> Optional[DownloadModelStatus]:
    """
    Short-circuit a download when something is already present at ``file_path``.

    When the path exists, a COMPLETED status (100%, ``already_existed=True``) is awaited
    through ``progress_callback`` and returned. Otherwise the callback is not invoked and
    None is returned.
    """
    if not os.path.exists(file_path):
        return None

    already_there = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
    await progress_callback(model_name, already_there)
    return already_there

async def track_download_progress(response: aiohttp.ClientResponse,
                                  file_path: str,
                                  model_name: str,
                                  progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                                  interval: float = 1.0) -> DownloadModelStatus:
    """
    Stream the body of ``response`` to ``file_path``, reporting progress along the way.

    The body is written to ``file_path + '.tmp'`` and moved into place only after it has
    been read completely, so ``file_path`` never holds a partial download. If writing or
    the move fails (or the task is cancelled), the temporary file is deleted.

    Args:
        response: Response whose body is read via ``response.content.iter_chunked(8192)``.
        file_path: Final destination of the model file.
        model_name: Name used in status messages and passed to ``progress_callback``.
        progress_callback: Awaited with IN_PROGRESS updates while streaming, then with a
            final COMPLETED status.
        interval: Minimum number of seconds between periodic IN_PROGRESS updates.

    Returns:
        DownloadModelStatus: COMPLETED on success; for any ``Exception`` the ERROR status
        produced by ``handle_download_error``.
    """
    temp_file_path = file_path + '.tmp'
    try:
        # A missing or malformed Content-Length only means the total size is unknown;
        # progress is then reported as 0 rather than failing the whole download.
        try:
            total_size = int(response.headers.get('Content-Length', 0))
        except (TypeError, ValueError):
            total_size = 0
        downloaded = 0
        last_update_time = time.time()

        async def update_progress():
            nonlocal last_update_time
            # Clamp in case more bytes arrive than Content-Length announced
            # (e.g. if the body is transparently decompressed).
            progress = min((downloaded / total_size) * 100, 100) if total_size > 0 else 0
            status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
            await progress_callback(model_name, status)
            last_update_time = time.time()

        try:
            with open(temp_file_path, 'wb') as f:
                async for chunk in response.content.iter_chunked(8192):
                    f.write(chunk)
                    downloaded += len(chunk)

                    if time.time() - last_update_time >= interval:
                        await update_progress()
            # os.replace, unlike os.rename, also overwrites an existing destination on Windows.
            os.replace(temp_file_path, file_path)
        except BaseException:
            # Best-effort removal of the partial file (also on cancellation); the original
            # error is re-raised and handled below or propagated to the caller.
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
            raise

        await update_progress()

        logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
        status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
        await progress_callback(model_name, status)

        return status
    except Exception as e:
        logging.error(f"Error in track_download_progress: {e}")
        logging.error(traceback.format_exc())
        return await handle_download_error(e, model_name, progress_callback)

async def handle_download_error(e: Exception,
                                model_name: str,
                                progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
                                ) -> DownloadModelStatus:
    """
    Convert an exception raised during a download into an ERROR status.

    The status is awaited through ``progress_callback`` and then returned.

    Args:
        e: The exception that aborted the download.
        model_name: Name of the model, used in the message and passed to the callback.
        progress_callback: Async callback notified with the ERROR status.

    Returns:
        DownloadModelStatus: ERROR status with progress 0 and a message describing ``e``.
    """
    # Exceptions created without arguments (e.g. a bare asyncio.TimeoutError()) stringify
    # to "", which would leave the message ending in ": "; use the class name instead.
    detail = str(e) or type(e).__name__
    error_message = f"Error downloading {model_name}: {detail}"
    status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
    await progress_callback(model_name, status)
    return status

def validate_filename(filename: str) -> bool:
    """
    Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.

    A valid name is a non-blank string of at most 255 characters that ends in ``.sft`` or
    ``.safetensors`` (case-insensitive). It must not start with a dot or contain ``..``,
    path separators or control characters, and it may use only ASCII letters, digits,
    ``_``, ``-``, ``.`` and spaces.

    Args:
        filename (str): The filename to validate

    Returns:
        bool: True if the filename is valid, False otherwise. Non-string input
        (including None) is reported as invalid rather than raising.
    """
    # Check if the filename is None, not a string, empty, or just whitespace. This must run
    # first: the checks below call str methods and would raise AttributeError on None.
    if not isinstance(filename, str) or not filename.strip():
        return False
    if not filename.lower().endswith(('.sft', '.safetensors')):
        return False
    # Check for any directory traversal attempts or invalid characters
    if any(token in filename for token in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
        return False
    # Check if the filename starts with a dot (hidden file)
    if filename.startswith('.'):
        return False
    # Use a whitelist of allowed characters; fullmatch anchors at the true end of the string.
    if not re.fullmatch(r'[a-zA-Z0-9_\-. ]+', filename):
        return False
    # Ensure the filename isn't too long
    if len(filename) > 255:
        return False
    return True