Merge branch 'enhanced-history-status' of https://github.com/realazthat/ComfyUI
This commit is contained in:
commit
1805cb2d69
46
execution.py
46
execution.py
|
@ -8,6 +8,7 @@ import heapq
|
||||||
import traceback
|
import traceback
|
||||||
import gc
|
import gc
|
||||||
import inspect
|
import inspect
|
||||||
|
from typing import List, Literal, NamedTuple, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
import nodes
|
||||||
|
@ -275,8 +276,15 @@ class PromptExecutor:
|
||||||
self.outputs = {}
|
self.outputs = {}
|
||||||
self.object_storage = {}
|
self.object_storage = {}
|
||||||
self.outputs_ui = {}
|
self.outputs_ui = {}
|
||||||
|
self.status_notes = []
|
||||||
|
self.success = True
|
||||||
self.old_prompt = {}
|
self.old_prompt = {}
|
||||||
|
|
||||||
|
def add_note(self, event, data, broadcast: bool):
|
||||||
|
self.status_notes.append((event, data))
|
||||||
|
if self.server.client_id is not None or broadcast:
|
||||||
|
self.server.send_sync(event, data, self.server.client_id)
|
||||||
|
|
||||||
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
||||||
node_id = error["node_id"]
|
node_id = error["node_id"]
|
||||||
class_type = prompt[node_id]["class_type"]
|
class_type = prompt[node_id]["class_type"]
|
||||||
|
@ -290,9 +298,8 @@ class PromptExecutor:
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
}
|
}
|
||||||
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
|
self.add_note("execution_interrupted", mes, broadcast=True)
|
||||||
else:
|
else:
|
||||||
if self.server.client_id is not None:
|
|
||||||
mes = {
|
mes = {
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
|
@ -305,7 +312,7 @@ class PromptExecutor:
|
||||||
"current_inputs": error["current_inputs"],
|
"current_inputs": error["current_inputs"],
|
||||||
"current_outputs": error["current_outputs"],
|
"current_outputs": error["current_outputs"],
|
||||||
}
|
}
|
||||||
self.server.send_sync("execution_error", mes, self.server.client_id)
|
self.add_note("execution_error", mes, broadcast=False)
|
||||||
|
|
||||||
# Next, remove the subsequent outputs since they will not be executed
|
# Next, remove the subsequent outputs since they will not be executed
|
||||||
to_delete = []
|
to_delete = []
|
||||||
|
@ -327,8 +334,7 @@ class PromptExecutor:
|
||||||
else:
|
else:
|
||||||
self.server.client_id = None
|
self.server.client_id = None
|
||||||
|
|
||||||
if self.server.client_id is not None:
|
self.add_note("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
#delete cached outputs if nodes don't exist for them
|
#delete cached outputs if nodes don't exist for them
|
||||||
|
@ -361,8 +367,9 @@ class PromptExecutor:
|
||||||
del d
|
del d
|
||||||
|
|
||||||
comfy.model_management.cleanup_models()
|
comfy.model_management.cleanup_models()
|
||||||
if self.server.client_id is not None:
|
self.add_note("execution_cached",
|
||||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
|
||||||
|
broadcast=False)
|
||||||
executed = set()
|
executed = set()
|
||||||
output_node_id = None
|
output_node_id = None
|
||||||
to_execute = []
|
to_execute = []
|
||||||
|
@ -378,8 +385,8 @@ class PromptExecutor:
|
||||||
# This call shouldn't raise anything if there's an error deep in
|
# This call shouldn't raise anything if there's an error deep in
|
||||||
# the actual SD code, instead it will report the node where the
|
# the actual SD code, instead it will report the node where the
|
||||||
# error was raised
|
# error was raised
|
||||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
||||||
if success is not True:
|
if self.success is not True:
|
||||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -731,14 +738,27 @@ class PromptQueue:
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
return (item, i)
|
return (item, i)
|
||||||
|
|
||||||
def task_done(self, item_id, outputs):
|
class ExecutionStatus(NamedTuple):
|
||||||
|
status_str: Literal['success', 'error']
|
||||||
|
completed: bool
|
||||||
|
notes: List[str]
|
||||||
|
|
||||||
|
def task_done(self, item_id, outputs,
|
||||||
|
status: Optional['PromptQueue.ExecutionStatus']):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
prompt = self.currently_running.pop(item_id)
|
prompt = self.currently_running.pop(item_id)
|
||||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||||
self.history.pop(next(iter(self.history)))
|
self.history.pop(next(iter(self.history)))
|
||||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
|
||||||
for o in outputs:
|
status_dict: dict|None = None
|
||||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
if status is not None:
|
||||||
|
status_dict = copy.deepcopy(status._asdict())
|
||||||
|
|
||||||
|
self.history[prompt[1]] = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"outputs": copy.deepcopy(outputs),
|
||||||
|
'status': status_dict,
|
||||||
|
}
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
|
|
||||||
def get_current_queue(self):
|
def get_current_queue(self):
|
||||||
|
|
7
main.py
7
main.py
|
@ -110,7 +110,12 @@ def prompt_worker(q, server):
|
||||||
|
|
||||||
e.execute(item[2], prompt_id, item[3], item[4])
|
e.execute(item[2], prompt_id, item[3], item[4])
|
||||||
need_gc = True
|
need_gc = True
|
||||||
q.task_done(item_id, e.outputs_ui)
|
q.task_done(item_id,
|
||||||
|
e.outputs_ui,
|
||||||
|
status=execution.PromptQueue.ExecutionStatus(
|
||||||
|
status_str='success' if e.success else 'error',
|
||||||
|
completed=e.success,
|
||||||
|
notes=e.status_notes))
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue