Only validate each input once.
This commit is contained in:
parent
02ca1c67f8
commit
d6dee8af1d
40
execution.py
40
execution.py
|
@ -147,7 +147,7 @@ class PromptExecutor:
|
||||||
self.old_prompt = {}
|
self.old_prompt = {}
|
||||||
self.server = server
|
self.server = server
|
||||||
|
|
||||||
def execute(self, prompt, extra_data={}):
|
def execute(self, prompt, extra_data={}, execute_outputs=[]):
|
||||||
nodes.interrupt_processing(False)
|
nodes.interrupt_processing(False)
|
||||||
|
|
||||||
if "client_id" in extra_data:
|
if "client_id" in extra_data:
|
||||||
|
@ -172,27 +172,15 @@ class PromptExecutor:
|
||||||
executed = set()
|
executed = set()
|
||||||
try:
|
try:
|
||||||
to_execute = []
|
to_execute = []
|
||||||
for x in prompt:
|
for x in list(execute_outputs):
|
||||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
to_execute += [(0, x)]
|
||||||
if hasattr(class_, 'OUTPUT_NODE'):
|
|
||||||
to_execute += [(0, x)]
|
|
||||||
|
|
||||||
while len(to_execute) > 0:
|
while len(to_execute) > 0:
|
||||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||||
x = to_execute.pop(0)[-1]
|
x = to_execute.pop(0)[-1]
|
||||||
|
|
||||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
|
||||||
if hasattr(class_, 'OUTPUT_NODE'):
|
|
||||||
if class_.OUTPUT_NODE == True:
|
|
||||||
valid = False
|
|
||||||
try:
|
|
||||||
m = validate_inputs(prompt, x)
|
|
||||||
valid = m[0]
|
|
||||||
except:
|
|
||||||
valid = False
|
|
||||||
if valid:
|
|
||||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, comfy.model_management.InterruptProcessingException):
|
if isinstance(e, comfy.model_management.InterruptProcessingException):
|
||||||
print("Processing interrupted")
|
print("Processing interrupted")
|
||||||
|
@ -219,8 +207,11 @@ class PromptExecutor:
|
||||||
comfy.model_management.soft_empty_cache()
|
comfy.model_management.soft_empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item):
|
def validate_inputs(prompt, item, validated):
|
||||||
unique_id = item
|
unique_id = item
|
||||||
|
if unique_id in validated:
|
||||||
|
return validated[unique_id]
|
||||||
|
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
class_type = prompt[unique_id]['class_type']
|
class_type = prompt[unique_id]['class_type']
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
@ -241,8 +232,9 @@ def validate_inputs(prompt, item):
|
||||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||||
if r[val[1]] != type_input:
|
if r[val[1]] != type_input:
|
||||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
||||||
r = validate_inputs(prompt, o_id)
|
r = validate_inputs(prompt, o_id, validated)
|
||||||
if r[0] == False:
|
if r[0] == False:
|
||||||
|
validated[o_id] = r
|
||||||
return r
|
return r
|
||||||
else:
|
else:
|
||||||
if type_input == "INT":
|
if type_input == "INT":
|
||||||
|
@ -270,7 +262,10 @@ def validate_inputs(prompt, item):
|
||||||
if isinstance(type_input, list):
|
if isinstance(type_input, list):
|
||||||
if val not in type_input:
|
if val not in type_input:
|
||||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||||
return (True, "")
|
|
||||||
|
ret = (True, "")
|
||||||
|
validated[unique_id] = ret
|
||||||
|
return ret
|
||||||
|
|
||||||
def validate_prompt(prompt):
|
def validate_prompt(prompt):
|
||||||
outputs = set()
|
outputs = set()
|
||||||
|
@ -284,11 +279,12 @@ def validate_prompt(prompt):
|
||||||
|
|
||||||
good_outputs = set()
|
good_outputs = set()
|
||||||
errors = []
|
errors = []
|
||||||
|
validated = {}
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
valid = False
|
valid = False
|
||||||
reason = ""
|
reason = ""
|
||||||
try:
|
try:
|
||||||
m = validate_inputs(prompt, o)
|
m = validate_inputs(prompt, o, validated)
|
||||||
valid = m[0]
|
valid = m[0]
|
||||||
reason = m[1]
|
reason = m[1]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -297,7 +293,7 @@ def validate_prompt(prompt):
|
||||||
reason = "Parsing error"
|
reason = "Parsing error"
|
||||||
|
|
||||||
if valid == True:
|
if valid == True:
|
||||||
good_outputs.add(x)
|
good_outputs.add(o)
|
||||||
else:
|
else:
|
||||||
print("Failed to validate prompt for output {} {}".format(o, reason))
|
print("Failed to validate prompt for output {} {}".format(o, reason))
|
||||||
print("output will be ignored")
|
print("output will be ignored")
|
||||||
|
@ -307,7 +303,7 @@ def validate_prompt(prompt):
|
||||||
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
||||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
||||||
|
|
||||||
return (True, "")
|
return (True, "", list(good_outputs))
|
||||||
|
|
||||||
|
|
||||||
class PromptQueue:
|
class PromptQueue:
|
||||||
|
|
2
main.py
2
main.py
|
@ -33,7 +33,7 @@ def prompt_worker(q, server):
|
||||||
e = execution.PromptExecutor(server)
|
e = execution.PromptExecutor(server)
|
||||||
while True:
|
while True:
|
||||||
item, item_id = q.get()
|
item, item_id = q.get()
|
||||||
e.execute(item[-2], item[-1])
|
e.execute(item[-3], item[-2], item[-1])
|
||||||
q.task_done(item_id, e.outputs)
|
q.task_done(item_id, e.outputs)
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
|
|
|
@ -312,7 +312,7 @@ class PromptServer():
|
||||||
if "client_id" in json_data:
|
if "client_id" in json_data:
|
||||||
extra_data["client_id"] = json_data["client_id"]
|
extra_data["client_id"] = json_data["client_id"]
|
||||||
if valid[0]:
|
if valid[0]:
|
||||||
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
|
self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2]))
|
||||||
else:
|
else:
|
||||||
resp_code = 400
|
resp_code = 400
|
||||||
out_string = valid[1]
|
out_string = valid[1]
|
||||||
|
|
Loading…
Reference in New Issue