From d6dee8af1df5e7dc80463b9e45bdce76767e4119 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 00:29:31 -0400 Subject: [PATCH] Only validate each input once. --- execution.py | 40 ++++++++++++++++++---------------------- main.py | 2 +- server.py | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/execution.py b/execution.py index edf88461..3953fde3 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}): + def execute(self, prompt, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -172,27 +172,15 @@ class PromptExecutor: executed = set() try: to_execute = [] - for x in prompt: - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - to_execute += [(0, x)] + for x in list(execute_outputs): + to_execute += [(0, x)] while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - if class_.OUTPUT_NODE == True: - valid = False - try: - m = validate_inputs(prompt, x) - valid = m[0] - except: - valid = False - if valid: - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -219,8 +207,11 @@ class PromptExecutor: comfy.model_management.soft_empty_cache() -def validate_inputs(prompt, item): +def validate_inputs(prompt, item, validated): unique_id = item + if unique_id in validated: + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -241,8 +232,9 @@ def validate_inputs(prompt, item): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id) + r = validate_inputs(prompt, o_id, validated) if r[0] == False: + validated[o_id] = r return r else: if type_input == "INT": @@ -270,7 +262,10 @@ def validate_inputs(prompt, item): if isinstance(type_input, list): if val not in type_input: return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) - return (True, "") + + ret = (True, "") + validated[unique_id] = ret + return ret def validate_prompt(prompt): outputs = set() @@ -284,11 +279,12 @@ def validate_prompt(prompt): good_outputs = set() errors = [] + validated = {} for o in outputs: valid = False reason = "" try: - m = validate_inputs(prompt, o) + m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] except Exception as e: @@ -297,7 +293,7 @@ def validate_prompt(prompt): reason = "Parsing error" if valid == True: - good_outputs.add(x) + good_outputs.add(o) else: print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") @@ -307,7 +303,7 @@ def validate_prompt(prompt): errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) - return (True, "") + return (True, "", list(good_outputs)) class PromptQueue: diff --git a/main.py b/main.py index eb97a2fb..d385df70 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() - e.execute(item[-2], item[-1]) + e.execute(item[-3], item[-2], item[-1]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): diff --git a/server.py b/server.py index c1226f30..b6ac7d48 100644 --- a/server.py +++ b/server.py @@ -312,7 +312,7 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data)) + self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) else: resp_code = 400 out_string = valid[1]