2023-01-03 06:53:32 +00:00
import os
import sys
2023-03-13 19:34:05 +00:00
import shutil
2023-02-28 00:43:55 +00:00
2023-01-03 06:53:32 +00:00
import threading
2023-02-12 15:53:48 +00:00
import asyncio
2023-01-03 06:53:32 +00:00
2023-02-16 18:19:26 +00:00
if os.name == "nt":
    import logging

    def _hide_missing_triton_warning(record):
        # Drop any "xformers" logger record whose message mentions the missing
        # Triton build; every other record passes through unchanged.
        return 'A matching Triton is not available' not in record.getMessage()

    logging.getLogger("xformers").addFilter(_hide_missing_triton_warning)
2023-02-08 03:12:56 +00:00
if __name__ == "__main__":
    # Command-line handling that runs before the project modules
    # (execution, server, folder_paths) are imported further down.
    if '--help' in sys.argv:
        print("Valid Command line Arguments:")
        print("\t--listen\t\t\tListen on 0.0.0.0 so the UI can be accessed from other computers.")
        print("\t--port 8188\t\t\tSet the listen port.")
        print("\t--dont-print-server\t\tRun the server with verbose output disabled.")
        print("\t--extra-model-paths-config PATH\tLoad extra model search paths from a yaml file.\n\t\t\t\t\tMay be given more than once.\n")
        print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
        print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
        print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
        print("\t--disable-xformers\t\tdisables xformers")
        print()
        print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
        print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
        print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
        print("\t--novram\t\t\tWhen lowvram isn't enough.")
        print()
        print("\t--cpu\t\t\t\tTo use the CPU for everything (slow).")
        # sys.exit() rather than the site-module exit() helper, which is meant
        # for the interactive shell and is absent under `python -S`.
        sys.exit()

    if '--dont-upcast-attention' in sys.argv:
        print("disabling upcasting of attention")
        # Set before the project modules are imported below; presumably they
        # read ATTN_PRECISION at import time — confirm.
        os.environ['ATTN_PRECISION'] = "fp16"
2023-01-29 18:12:22 +00:00
2023-03-13 15:36:48 +00:00
import execution
import server
2023-03-18 06:52:43 +00:00
import folder_paths
import yaml
2023-03-13 15:36:48 +00:00
2023-02-21 19:29:49 +00:00
def prompt_worker(q, server):
    """Execute queued prompts forever; intended as the body of a worker thread.

    Repeatedly takes an ``(item, item_id)`` pair from ``q``, runs the prompt
    by passing the last two fields of ``item`` to ``PromptExecutor.execute``,
    then reports completion with ``q.task_done(item_id, e.outputs)``.

    If ``execute`` raises, the traceback is printed and the item is still
    marked done (with ``e.outputs`` as they stand). Without this, a single
    failing prompt would end the thread, leave its item unfinished, and stop
    every prompt queued after it from ever running.

    Args:
        q: the prompt queue; must provide ``get()`` and ``task_done(id, outputs)``.
        server: passed through to ``execution.PromptExecutor``.
    """
    import traceback

    e = execution.PromptExecutor(server)
    while True:
        item, item_id = q.get()
        try:
            e.execute(item[-2], item[-1])
        except Exception:
            # Keep the worker alive: report the failure and move on.
            traceback.print_exc()
        finally:
            q.task_done(item_id, e.outputs)
2023-01-03 06:53:32 +00:00
2023-03-12 19:44:16 +00:00
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    """Run the server's start-up coroutine and its publish loop concurrently.

    Both coroutines are created up front and awaited together with
    ``asyncio.gather``; this returns when both have finished.
    """
    startup = server.start(address, port, verbose, call_on_start)
    publisher = server.publish_loop()
    await asyncio.gather(startup, publisher)
2023-01-03 06:53:32 +00:00
2023-02-21 19:29:49 +00:00
def hijack_progress(server):
    """Monkey-patch ``tqdm.auto.tqdm.update`` so progress is also sent to a client.

    After the original ``update`` runs, the patched method calls
    ``server.send_sync("progress", {"value": n, "max": total}, server.client_id)``
    with the bar's current ``n`` and ``total``, then returns what the original
    returned.
    """
    from tqdm.auto import tqdm

    original_update = tqdm.update

    def update_and_report(*args, **kwargs):
        bar = args[0]  # the tqdm instance (``self``)
        result = original_update(*args, **kwargs)
        payload = {"value": bar.n, "max": bar.total}
        server.send_sync("progress", payload, server.client_id)
        return result

    tqdm.update = update_and_report
2023-01-03 06:53:32 +00:00
2023-03-13 19:34:05 +00:00
def cleanup_temp():
    """Delete the ``temp`` directory located next to this script, if present.

    Deletion is best-effort: ``ignore_errors=True`` makes ``shutil.rmtree``
    skip anything it cannot remove instead of raising.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    temp_dir = os.path.join(script_dir, "temp")
    if not os.path.exists(temp_dir):
        return
    shutil.rmtree(temp_dir, ignore_errors=True)
2023-03-13 19:34:05 +00:00
2023-03-18 06:52:43 +00:00
def load_extra_path_config(yaml_path):
    """Register extra model search folders listed in a YAML file.

    The file maps section names to mappings of ``folder_name: paths``, where
    ``paths`` is either a newline-separated string (e.g. a YAML ``|`` block)
    or a list of strings. Empty entries are skipped. If a section has a
    ``base_path`` key, it is removed from the section and prepended to each
    of that section's paths with ``os.path.join``. Every resulting path is
    printed and passed to ``folder_paths.add_model_folder_path``.

    An empty file, a section whose value is null, and a folder whose value
    is null are all ignored.
    """
    # Open in binary so PyYAML detects the encoding (UTF-8, or UTF-16 via BOM)
    # instead of decoding with the locale default (e.g. cp1252 on Windows).
    with open(yaml_path, 'rb') as stream:
        config = yaml.safe_load(stream)
    if config is None:
        return  # empty file: nothing to register
    for c in config:
        conf = config[c]
        if conf is None:
            continue
        base_path = None
        if "base_path" in conf:
            base_path = conf.pop("base_path")
        for x in conf:
            entries = conf[x]
            if entries is None:
                continue
            if isinstance(entries, str):
                entries = entries.split("\n")
            for y in entries:
                if not y:
                    continue
                full_path = y
                if base_path is not None:
                    full_path = os.path.join(base_path, full_path)
                print("Adding extra search path", x, full_path)
                folder_paths.add_model_folder_path(x, full_path)
2023-01-03 06:53:32 +00:00
if __name__ == "__main__":
    # Remove leftovers from a previous run before the server starts.
    cleanup_temp()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # NOTE: this rebinds the module-level name `server` from the imported
    # module to the PromptServer instance; only the instance is used below.
    server = server.PromptServer(loop)
    q = execution.PromptQueue(server)

    # Forward tqdm progress to clients, then execute queued prompts on a
    # daemon thread so it does not keep the process alive on shutdown.
    hijack_progress(server)
    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()

    if '--listen' in sys.argv:
        address = '0.0.0.0'
    else:
        address = '127.0.0.1'

    dont_print = '--dont-print-server' in sys.argv

    # An extra_model_paths.yaml next to this script is always loaded if present.
    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
    if os.path.isfile(extra_model_paths_config_path):
        load_extra_path_config(extra_model_paths_config_path)
    # Load the path following every --extra-model-paths-config; stopping one
    # short of the end skips a trailing flag that has no value.
    if '--extra-model-paths-config' in sys.argv:
        indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config']
        for i in indices:
            load_extra_path_config(sys.argv[i])

    port = 8188
    if '--port' in sys.argv:
        p_index = sys.argv.index('--port')
        try:
            port = int(sys.argv[p_index + 1])
        except (IndexError, ValueError):
            print("Invalid or missing value for --port, using default port", port)

    if '--quick-test-for-ci' in sys.argv:
        sys.exit(0)

    call_on_start = None
    if "--windows-standalone-build" in sys.argv:
        def startup_server(address, port):
            import webbrowser
            # 0.0.0.0 is a bind-all address, not one a browser can open;
            # point the browser at loopback instead.
            if address == '0.0.0.0':
                address = '127.0.0.1'
            webbrowser.open("http://{}:{}".format(address, port))
        call_on_start = startup_server

    try:
        loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
    except KeyboardInterrupt:
        # Same per-platform split as before: Ctrl+C is swallowed on Windows
        # and propagates elsewhere.
        if os.name != "nt":
            raise
    finally:
        # Runs on every exit path, including a propagating KeyboardInterrupt.
        cleanup_temp()