diff --git a/execution.py b/execution.py index 632aaa84..5ed9ff34 100644 --- a/execution.py +++ b/execution.py @@ -102,13 +102,19 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui +def format_value(x): + if isinstance(x, (int, float, bool, str)): + return x + else: + return str(x) + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return + return (True, None, None) for x in inputs: input_data = inputs[x] @@ -117,22 +123,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + if result[0] is not True: + # Another node failed further upstream + return result - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) - obj = class_def() - - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + input_data_all = None + try: + input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.last_node_id = unique_id + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + obj = class_def() + + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui + if server.client_id is not None: + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + except comfy.model_management.InterruptProcessingException as iex: + print("Processing interrupted") + + # skip formatting inputs/outputs + error_details = { + "node_id": unique_id, + } + + return (False, error_details, iex) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputs in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputs] + + output_data_formatted = {} + for node_id, node_outputs in outputs.items(): + output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + + print("!!! Exception during processing !!!") + print(traceback.format_exc()) + + error_details = { + "node_id": unique_id, + "message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "current_inputs": input_data_formatted, + "current_outputs": output_data_formatted + } + return (False, error_details, ex) + executed.add(unique_id) + return (True, None, None) + def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -210,6 +258,44 @@ class PromptExecutor: self.old_prompt = {} self.server = server + def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + # First, send back the status to the frontend depending + # on the exception type + if isinstance(ex, comfy.model_management.InterruptProcessingException): + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "node_id": error["node_id"], + } + self.server.send_sync("execution_interrupted", mes, self.server.client_id) + else: + if self.server.client_id is not None: + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "message": error["message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "node_id": error["node_id"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.server.send_sync("execution_error", mes, self.server.client_id) + + # Next, remove the subsequent outputs since they will not be executed + to_delete = [] + for o in self.outputs: + if (o not in current_outputs) and (o not in executed): + to_delete += [o] + if o in self.old_prompt: + d = self.old_prompt.pop(o) + del d + for o in to_delete: + d = self.outputs.pop(o) + del d + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -244,42 +330,29 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() - try: - to_execute = [] - for x in list(execute_outputs): - to_execute += [(0, x)] + output_node_id = None + to_execute = [] - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - x = to_execute.pop(0)[-1] + for node_id in list(execute_outputs): + to_execute += [(0, node_id)] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) - except Exception as e: - if isinstance(e, comfy.model_management.InterruptProcessingException): - print("Processing interrupted") - else: - message = str(traceback.format_exc()) - print(message) - if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) + while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + output_node_id = to_execute.pop(0)[-1] - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - finally: - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) - self.server.last_node_id = None - if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + # This call shouldn't raise anything if there's an error deep in + # the actual SD code, instead it will report the node where the + # error was raised + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) + if success is not True: + self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + + for x in executed: + self.old_prompt[x] = copy.deepcopy(prompt[x]) + self.server.last_node_id = None + if self.server.client_id is not None: + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() @@ -359,7 +432,7 @@ def validate_inputs(prompt, item, validated): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", @@ -367,7 +440,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] @@ -507,13 +580,13 @@ def validate_prompt(prompt): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", "details": str(ex), "extra_info": { - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }]