Only validate each input once.

This commit is contained in:
comfyanonymous 2023-05-10 00:29:31 -04:00
parent 02ca1c67f8
commit d6dee8af1d
3 changed files with 20 additions and 24 deletions

View File

@ -147,7 +147,7 @@ class PromptExecutor:
self.old_prompt = {} self.old_prompt = {}
self.server = server self.server = server
def execute(self, prompt, extra_data={}): def execute(self, prompt, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False) nodes.interrupt_processing(False)
if "client_id" in extra_data: if "client_id" in extra_data:
@ -172,27 +172,15 @@ class PromptExecutor:
executed = set() executed = set()
try: try:
to_execute = [] to_execute = []
for x in prompt: for x in list(execute_outputs):
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] to_execute += [(0, x)]
if hasattr(class_, 'OUTPUT_NODE'):
to_execute += [(0, x)]
while len(to_execute) > 0: while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first #always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1] x = to_execute.pop(0)[-1]
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
if hasattr(class_, 'OUTPUT_NODE'):
if class_.OUTPUT_NODE == True:
valid = False
try:
m = validate_inputs(prompt, x)
valid = m[0]
except:
valid = False
if valid:
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
except Exception as e: except Exception as e:
if isinstance(e, comfy.model_management.InterruptProcessingException): if isinstance(e, comfy.model_management.InterruptProcessingException):
print("Processing interrupted") print("Processing interrupted")
@ -219,8 +207,11 @@ class PromptExecutor:
comfy.model_management.soft_empty_cache() comfy.model_management.soft_empty_cache()
def validate_inputs(prompt, item): def validate_inputs(prompt, item, validated):
unique_id = item unique_id = item
if unique_id in validated:
return validated[unique_id]
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -241,8 +232,9 @@ def validate_inputs(prompt, item):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input: if r[val[1]] != type_input:
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
r = validate_inputs(prompt, o_id) r = validate_inputs(prompt, o_id, validated)
if r[0] == False: if r[0] == False:
validated[o_id] = r
return r return r
else: else:
if type_input == "INT": if type_input == "INT":
@ -270,7 +262,10 @@ def validate_inputs(prompt, item):
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
return (True, "")
ret = (True, "")
validated[unique_id] = ret
return ret
def validate_prompt(prompt): def validate_prompt(prompt):
outputs = set() outputs = set()
@ -284,11 +279,12 @@ def validate_prompt(prompt):
good_outputs = set() good_outputs = set()
errors = [] errors = []
validated = {}
for o in outputs: for o in outputs:
valid = False valid = False
reason = "" reason = ""
try: try:
m = validate_inputs(prompt, o) m = validate_inputs(prompt, o, validated)
valid = m[0] valid = m[0]
reason = m[1] reason = m[1]
except Exception as e: except Exception as e:
@ -297,7 +293,7 @@ def validate_prompt(prompt):
reason = "Parsing error" reason = "Parsing error"
if valid == True: if valid == True:
good_outputs.add(x) good_outputs.add(o)
else: else:
print("Failed to validate prompt for output {} {}".format(o, reason)) print("Failed to validate prompt for output {} {}".format(o, reason))
print("output will be ignored") print("output will be ignored")
@ -307,7 +303,7 @@ def validate_prompt(prompt):
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
return (True, "") return (True, "", list(good_outputs))
class PromptQueue: class PromptQueue:

View File

@ -33,7 +33,7 @@ def prompt_worker(q, server):
e = execution.PromptExecutor(server) e = execution.PromptExecutor(server)
while True: while True:
item, item_id = q.get() item, item_id = q.get()
e.execute(item[-2], item[-1]) e.execute(item[-3], item[-2], item[-1])
q.task_done(item_id, e.outputs) q.task_done(item_id, e.outputs)
async def run(server, address='', port=8188, verbose=True, call_on_start=None): async def run(server, address='', port=8188, verbose=True, call_on_start=None):

View File

@ -312,7 +312,7 @@ class PromptServer():
if "client_id" in json_data: if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"] extra_data["client_id"] = json_data["client_id"]
if valid[0]: if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data)) self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2]))
else: else:
resp_code = 400 resp_code = 400
out_string = valid[1] out_string = valid[1]