2023-01-03 06:53:32 +00:00
import os
import sys
import copy
import json
import threading
2023-02-02 03:33:10 +00:00
import heapq
2023-01-03 06:53:32 +00:00
import traceback
2023-02-08 03:12:56 +00:00
if __name__ == " __main__ " :
if ' --help ' in sys . argv :
print ( " Valid Command line Arguments: " )
print ( " \t --listen \t \t \t Listen on 0.0.0.0 so the UI can be accessed from other computers. " )
print ( " \t --port 8188 \t \t \t Set the listen port. " )
print ( " \t --dont-upcast-attention \t \t Disable upcasting of attention \n \t \t \t \t \t can boost speed but increase the chances of black images. \n " )
print ( " \t --use-split-cross-attention \t Use the split cross attention optimization instead of the sub-quadratic one. \n \t \t \t \t \t Ignored when xformers is used. " )
print ( )
2023-02-08 16:37:10 +00:00
print ( " \t --lowvram \t \t \t Split the unet in parts to use less vram. " )
print ( " \t --novram \t \t \t When lowvram isn ' t enough. " )
print ( )
2023-02-08 03:12:56 +00:00
exit ( )
2023-01-29 18:12:22 +00:00
if ' --dont-upcast-attention ' in sys . argv :
print ( " disabling upcasting of attention " )
os . environ [ ' ATTN_PRECISION ' ] = " fp16 "
2023-01-03 06:53:32 +00:00
import torch
import nodes
2023-01-23 02:42:22 +00:00
def get_input_data ( inputs , class_def , outputs = { } , prompt = { } , extra_data = { } ) :
valid_inputs = class_def . INPUT_TYPES ( )
2023-01-03 06:53:32 +00:00
input_data_all = { }
for x in inputs :
input_data = inputs [ x ]
if isinstance ( input_data , list ) :
input_unique_id = input_data [ 0 ]
output_index = input_data [ 1 ]
obj = outputs [ input_unique_id ] [ output_index ]
input_data_all [ x ] = obj
else :
if ( " required " in valid_inputs and x in valid_inputs [ " required " ] ) or ( " optional " in valid_inputs and x in valid_inputs [ " optional " ] ) :
input_data_all [ x ] = input_data
if " hidden " in valid_inputs :
h = valid_inputs [ " hidden " ]
for x in h :
if h [ x ] == " PROMPT " :
input_data_all [ x ] = prompt
if h [ x ] == " EXTRA_PNGINFO " :
if " extra_pnginfo " in extra_data :
input_data_all [ x ] = extra_data [ ' extra_pnginfo ' ]
2023-01-23 02:42:22 +00:00
return input_data_all
def recursive_execute ( prompt , outputs , current_item , extra_data = { } ) :
unique_id = current_item
inputs = prompt [ unique_id ] [ ' inputs ' ]
class_type = prompt [ unique_id ] [ ' class_type ' ]
class_def = nodes . NODE_CLASS_MAPPINGS [ class_type ]
if unique_id in outputs :
return [ ]
executed = [ ]
for x in inputs :
input_data = inputs [ x ]
if isinstance ( input_data , list ) :
input_unique_id = input_data [ 0 ]
output_index = input_data [ 1 ]
if input_unique_id not in outputs :
executed + = recursive_execute ( prompt , outputs , input_unique_id , extra_data )
input_data_all = get_input_data ( inputs , class_def , outputs , prompt , extra_data )
obj = class_def ( )
2023-01-03 06:53:32 +00:00
outputs [ unique_id ] = getattr ( obj , obj . FUNCTION ) ( * * input_data_all )
return executed + [ unique_id ]
2023-01-23 03:10:13 +00:00
def recursive_will_execute ( prompt , outputs , current_item ) :
unique_id = current_item
inputs = prompt [ unique_id ] [ ' inputs ' ]
will_execute = [ ]
if unique_id in outputs :
return [ ]
for x in inputs :
input_data = inputs [ x ]
if isinstance ( input_data , list ) :
input_unique_id = input_data [ 0 ]
output_index = input_data [ 1 ]
if input_unique_id not in outputs :
will_execute + = recursive_will_execute ( prompt , outputs , input_unique_id )
return will_execute + [ unique_id ]
2023-01-03 06:53:32 +00:00
def recursive_output_delete_if_changed ( prompt , old_prompt , outputs , current_item ) :
unique_id = current_item
inputs = prompt [ unique_id ] [ ' inputs ' ]
class_type = prompt [ unique_id ] [ ' class_type ' ]
2023-01-23 02:42:22 +00:00
class_def = nodes . NODE_CLASS_MAPPINGS [ class_type ]
is_changed_old = ' '
is_changed = ' '
if hasattr ( class_def , ' IS_CHANGED ' ) :
2023-02-04 21:08:29 +00:00
if unique_id in old_prompt and ' is_changed ' in old_prompt [ unique_id ] :
is_changed_old = old_prompt [ unique_id ] [ ' is_changed ' ]
2023-01-23 02:42:22 +00:00
if ' is_changed ' not in prompt [ unique_id ] :
input_data_all = get_input_data ( inputs , class_def )
is_changed = class_def . IS_CHANGED ( * * input_data_all )
prompt [ unique_id ] [ ' is_changed ' ] = is_changed
else :
is_changed = prompt [ unique_id ] [ ' is_changed ' ]
2023-01-03 06:53:32 +00:00
if unique_id not in outputs :
return True
to_delete = False
2023-01-23 02:42:22 +00:00
if is_changed != is_changed_old :
to_delete = True
elif unique_id not in old_prompt :
2023-01-03 06:53:32 +00:00
to_delete = True
elif inputs == old_prompt [ unique_id ] [ ' inputs ' ] :
for x in inputs :
input_data = inputs [ x ]
if isinstance ( input_data , list ) :
input_unique_id = input_data [ 0 ]
output_index = input_data [ 1 ]
if input_unique_id in outputs :
to_delete = recursive_output_delete_if_changed ( prompt , old_prompt , outputs , input_unique_id )
else :
to_delete = True
if to_delete :
break
else :
to_delete = True
if to_delete :
print ( " deleted " , unique_id )
d = outputs . pop ( unique_id )
del d
return to_delete
class PromptExecutor :
def __init__ ( self ) :
self . outputs = { }
self . old_prompt = { }
def execute ( self , prompt , extra_data = { } ) :
with torch . no_grad ( ) :
for x in prompt :
recursive_output_delete_if_changed ( prompt , self . old_prompt , self . outputs , x )
current_outputs = set ( self . outputs . keys ( ) )
executed = [ ]
try :
2023-01-23 03:10:13 +00:00
to_execute = [ ]
2023-01-03 06:53:32 +00:00
for x in prompt :
2023-01-23 03:10:13 +00:00
class_ = nodes . NODE_CLASS_MAPPINGS [ prompt [ x ] [ ' class_type ' ] ]
if hasattr ( class_ , ' OUTPUT_NODE ' ) :
to_execute + = [ ( 0 , x ) ]
while len ( to_execute ) > 0 :
2023-02-02 03:33:10 +00:00
#always execute the output that depends on the least amount of unexecuted nodes first
2023-01-23 03:10:13 +00:00
to_execute = sorted ( list ( map ( lambda a : ( len ( recursive_will_execute ( prompt , self . outputs , a [ - 1 ] ) ) , a [ - 1 ] ) , to_execute ) ) )
x = to_execute . pop ( 0 ) [ - 1 ]
2023-01-03 06:53:32 +00:00
class_ = nodes . NODE_CLASS_MAPPINGS [ prompt [ x ] [ ' class_type ' ] ]
if hasattr ( class_ , ' OUTPUT_NODE ' ) :
if class_ . OUTPUT_NODE == True :
valid = False
try :
m = validate_inputs ( prompt , x )
valid = m [ 0 ]
except :
valid = False
if valid :
executed + = recursive_execute ( prompt , self . outputs , x , extra_data )
2023-01-23 03:10:13 +00:00
2023-01-03 06:53:32 +00:00
except Exception as e :
print ( traceback . format_exc ( ) )
to_delete = [ ]
for o in self . outputs :
if o not in current_outputs :
to_delete + = [ o ]
if o in self . old_prompt :
d = self . old_prompt . pop ( o )
del d
for o in to_delete :
d = self . outputs . pop ( o )
del d
else :
executed = set ( executed )
for x in executed :
self . old_prompt [ x ] = copy . deepcopy ( prompt [ x ] )
def validate_inputs ( prompt , item ) :
unique_id = item
inputs = prompt [ unique_id ] [ ' inputs ' ]
class_type = prompt [ unique_id ] [ ' class_type ' ]
obj_class = nodes . NODE_CLASS_MAPPINGS [ class_type ]
class_inputs = obj_class . INPUT_TYPES ( )
required_inputs = class_inputs [ ' required ' ]
for x in required_inputs :
if x not in inputs :
return ( False , " Required input is missing. {} , {} " . format ( class_type , x ) )
val = inputs [ x ]
info = required_inputs [ x ]
type_input = info [ 0 ]
if isinstance ( val , list ) :
if len ( val ) != 2 :
return ( False , " Bad Input. {} , {} " . format ( class_type , x ) )
o_id = val [ 0 ]
o_class_type = prompt [ o_id ] [ ' class_type ' ]
r = nodes . NODE_CLASS_MAPPINGS [ o_class_type ] . RETURN_TYPES
if r [ val [ 1 ] ] != type_input :
return ( False , " Return type mismatch. {} , {} " . format ( class_type , x ) )
r = validate_inputs ( prompt , o_id )
if r [ 0 ] == False :
return r
else :
if type_input == " INT " :
val = int ( val )
inputs [ x ] = val
if type_input == " FLOAT " :
val = float ( val )
inputs [ x ] = val
if type_input == " STRING " :
val = str ( val )
inputs [ x ] = val
if len ( info ) > 1 :
if " min " in info [ 1 ] and val < info [ 1 ] [ " min " ] :
return ( False , " Value smaller than min. {} , {} " . format ( class_type , x ) )
if " max " in info [ 1 ] and val > info [ 1 ] [ " max " ] :
return ( False , " Value bigger than max. {} , {} " . format ( class_type , x ) )
if isinstance ( type_input , list ) :
if val not in type_input :
2023-01-27 04:30:29 +00:00
return ( False , " Value not in list. {} , {} : {} not in {} " . format ( class_type , x , val , type_input ) )
2023-01-03 06:53:32 +00:00
return ( True , " " )
def validate_prompt ( prompt ) :
outputs = set ( )
for x in prompt :
class_ = nodes . NODE_CLASS_MAPPINGS [ prompt [ x ] [ ' class_type ' ] ]
if hasattr ( class_ , ' OUTPUT_NODE ' ) and class_ . OUTPUT_NODE == True :
outputs . add ( x )
if len ( outputs ) == 0 :
return ( False , " Prompt has no outputs " )
good_outputs = set ( )
2023-01-27 04:30:29 +00:00
errors = [ ]
2023-01-03 06:53:32 +00:00
for o in outputs :
valid = False
reason = " "
try :
m = validate_inputs ( prompt , o )
valid = m [ 0 ]
reason = m [ 1 ]
except :
valid = False
reason = " Parsing error "
if valid == True :
good_outputs . add ( x )
else :
print ( " Failed to validate prompt for output {} {} " . format ( o , reason ) )
print ( " output will be ignored " )
2023-01-27 04:30:29 +00:00
errors + = [ ( o , reason ) ]
2023-01-03 06:53:32 +00:00
if len ( good_outputs ) == 0 :
2023-01-27 04:30:29 +00:00
errors_list = " \n " . join ( map ( lambda a : " {} " . format ( a [ 1 ] ) , errors ) )
return ( False , " Prompt has no properly connected outputs \n {} " . format ( errors_list ) )
2023-01-03 06:53:32 +00:00
return ( True , " " )
def prompt_worker ( q ) :
e = PromptExecutor ( )
while True :
2023-02-02 03:33:10 +00:00
item , item_id = q . get ( )
2023-01-03 06:53:32 +00:00
e . execute ( item [ - 2 ] , item [ - 1 ] )
2023-02-02 03:33:10 +00:00
q . task_done ( item_id )
2023-01-03 06:53:32 +00:00
2023-02-02 03:33:10 +00:00
class PromptQueue :
def __init__ ( self ) :
self . mutex = threading . RLock ( )
self . not_empty = threading . Condition ( self . mutex )
self . task_counter = 0
self . queue = [ ]
self . currently_running = { }
def put ( self , item ) :
with self . mutex :
heapq . heappush ( self . queue , item )
self . not_empty . notify ( )
def get ( self ) :
with self . not_empty :
while len ( self . queue ) == 0 :
self . not_empty . wait ( )
item = heapq . heappop ( self . queue )
i = self . task_counter
self . currently_running [ i ] = copy . deepcopy ( item )
self . task_counter + = 1
return ( item , i )
def task_done ( self , item_id ) :
with self . mutex :
self . currently_running . pop ( item_id )
def get_current_queue ( self ) :
with self . mutex :
out = [ ]
for x in self . currently_running . values ( ) :
out + = [ x ]
return ( out , copy . deepcopy ( self . queue ) )
def get_tasks_remaining ( self ) :
with self . mutex :
return len ( self . queue ) + len ( self . currently_running )
def wipe_queue ( self ) :
with self . mutex :
self . queue = [ ]
def delete_queue_item ( self , function ) :
with self . mutex :
for x in range ( len ( self . queue ) ) :
if function ( self . queue [ x ] ) :
if len ( self . queue ) == 1 :
self . wipe_queue ( )
else :
self . queue . pop ( x )
heapq . heapify ( self . queue )
return True
return False
2023-01-03 06:53:32 +00:00
from http . server import BaseHTTPRequestHandler , HTTPServer
class PromptServer ( BaseHTTPRequestHandler ) :
def _set_headers ( self , code = 200 , ct = ' text/html ' ) :
self . send_response ( code )
self . send_header ( ' Content-type ' , ct )
self . end_headers ( )
def log_message ( self , format , * args ) :
pass
def do_GET ( self ) :
if self . path == " /prompt " :
self . _set_headers ( ct = ' application/json ' )
prompt_info = { }
exec_info = { }
2023-02-02 03:33:10 +00:00
exec_info [ ' queue_remaining ' ] = self . server . prompt_queue . get_tasks_remaining ( )
2023-01-03 06:53:32 +00:00
prompt_info [ ' exec_info ' ] = exec_info
self . wfile . write ( json . dumps ( prompt_info ) . encode ( ' utf-8 ' ) )
2023-02-02 03:33:10 +00:00
elif self . path == " /queue " :
self . _set_headers ( ct = ' application/json ' )
queue_info = { }
current_queue = self . server . prompt_queue . get_current_queue ( )
queue_info [ ' queue_running ' ] = current_queue [ 0 ]
queue_info [ ' queue_pending ' ] = current_queue [ 1 ]
self . wfile . write ( json . dumps ( queue_info ) . encode ( ' utf-8 ' ) )
2023-01-03 06:53:32 +00:00
elif self . path == " /object_info " :
self . _set_headers ( ct = ' application/json ' )
out = { }
for x in nodes . NODE_CLASS_MAPPINGS :
obj_class = nodes . NODE_CLASS_MAPPINGS [ x ]
info = { }
info [ ' input ' ] = obj_class . INPUT_TYPES ( )
info [ ' output ' ] = obj_class . RETURN_TYPES
info [ ' name ' ] = x #TODO
info [ ' description ' ] = ' '
2023-01-26 06:26:28 +00:00
info [ ' category ' ] = ' sd '
if hasattr ( obj_class , ' CATEGORY ' ) :
info [ ' category ' ] = obj_class . CATEGORY
2023-01-03 06:53:32 +00:00
out [ x ] = info
self . wfile . write ( json . dumps ( out ) . encode ( ' utf-8 ' ) )
elif self . path [ 1 : ] in os . listdir ( self . server . server_dir ) :
2023-01-25 01:18:16 +00:00
if self . path [ 1 : ] . endswith ( ' .css ' ) :
self . _set_headers ( ct = ' text/css ' )
elif self . path [ 1 : ] . endswith ( ' .js ' ) :
self . _set_headers ( ct = ' text/javascript ' )
else :
self . _set_headers ( )
2023-01-03 06:53:32 +00:00
with open ( os . path . join ( self . server . server_dir , self . path [ 1 : ] ) , " rb " ) as f :
self . wfile . write ( f . read ( ) )
else :
self . _set_headers ( )
with open ( os . path . join ( self . server . server_dir , " index.html " ) , " rb " ) as f :
self . wfile . write ( f . read ( ) )
def do_HEAD ( self ) :
self . _set_headers ( )
def do_POST ( self ) :
resp_code = 200
out_string = " "
if self . path == " /prompt " :
print ( " got prompt " )
2023-02-02 03:33:10 +00:00
data_string = self . rfile . read ( int ( self . headers [ ' Content-Length ' ] ) )
json_data = json . loads ( data_string )
2023-01-03 06:53:32 +00:00
if " number " in json_data :
number = float ( json_data [ ' number ' ] )
else :
number = self . server . number
2023-02-02 03:33:10 +00:00
if " front " in json_data :
if json_data [ ' front ' ] :
number = - number
2023-01-03 06:53:32 +00:00
self . server . number + = 1
if " prompt " in json_data :
prompt = json_data [ " prompt " ]
valid = validate_prompt ( prompt )
extra_data = { }
if " extra_data " in json_data :
extra_data = json_data [ " extra_data " ]
if valid [ 0 ] :
self . server . prompt_queue . put ( ( number , id ( prompt ) , prompt , extra_data ) )
else :
resp_code = 400
out_string = valid [ 1 ]
print ( " invalid prompt: " , valid [ 1 ] )
2023-02-02 03:33:10 +00:00
elif self . path == " /queue " :
data_string = self . rfile . read ( int ( self . headers [ ' Content-Length ' ] ) )
json_data = json . loads ( data_string )
if " clear " in json_data :
if json_data [ " clear " ] :
self . server . prompt_queue . wipe_queue ( )
if " delete " in json_data :
to_delete = json_data [ ' delete ' ]
for id_to_delete in to_delete :
delete_func = lambda a : a [ 1 ] == int ( id_to_delete )
self . server . prompt_queue . delete_queue_item ( delete_func )
2023-01-03 06:53:32 +00:00
self . _set_headers ( code = resp_code )
self . end_headers ( )
self . wfile . write ( out_string . encode ( ' utf8 ' ) )
return
def run ( prompt_queue , address = ' ' , port = 8188 ) :
server_address = ( address , port )
httpd = HTTPServer ( server_address , PromptServer )
httpd . server_dir = os . path . join ( os . path . dirname ( os . path . realpath ( __file__ ) ) , " webshit " )
httpd . prompt_queue = prompt_queue
httpd . number = 0
if server_address [ 0 ] == ' ' :
addr = ' 0.0.0.0 '
else :
addr = server_address [ 0 ]
print ( " Starting server \n " )
print ( " To see the GUI go to: http:// {} : {} " . format ( addr , server_address [ 1 ] ) )
httpd . serve_forever ( )
if __name__ == " __main__ " :
2023-02-02 03:33:10 +00:00
q = PromptQueue ( )
2023-01-03 06:53:32 +00:00
threading . Thread ( target = prompt_worker , daemon = True , args = ( q , ) ) . start ( )
2023-02-04 17:01:53 +00:00
if ' --listen ' in sys . argv :
address = ' 0.0.0.0 '
else :
address = ' 127.0.0.1 '
2023-02-08 02:57:17 +00:00
port = 8188
try :
p_index = sys . argv . index ( ' --port ' )
port = int ( sys . argv [ p_index + 1 ] )
except :
pass
run ( q , address = address , port = port )
2023-01-03 06:53:32 +00:00