Add error, status to /history endpoint

This commit is contained in:
realazthat 2024-01-11 08:38:18 -05:00
parent 977eda19a6
commit 1b3d65bd84
2 changed files with 51 additions and 26 deletions

View File

@ -8,6 +8,7 @@ import heapq
import traceback import traceback
import gc import gc
import inspect import inspect
from typing import List, Literal, NamedTuple, Optional
import torch import torch
import nodes import nodes
@ -275,8 +276,15 @@ class PromptExecutor:
self.outputs = {} self.outputs = {}
self.object_storage = {} self.object_storage = {}
self.outputs_ui = {} self.outputs_ui = {}
self.status_notes = []
self.success = True
self.old_prompt = {} self.old_prompt = {}
def add_note(self, event, data, broadcast: bool):
self.status_notes.append((event, data))
if self.server.client_id is not None or broadcast:
self.server.send_sync(event, data, self.server.client_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"] node_id = error["node_id"]
class_type = prompt[node_id]["class_type"] class_type = prompt[node_id]["class_type"]
@ -290,9 +298,8 @@ class PromptExecutor:
"node_type": class_type, "node_type": class_type,
"executed": list(executed), "executed": list(executed),
} }
self.server.send_sync("execution_interrupted", mes, self.server.client_id) self.add_note("execution_interrupted", mes, broadcast=True)
else: else:
if self.server.client_id is not None:
mes = { mes = {
"prompt_id": prompt_id, "prompt_id": prompt_id,
"node_id": node_id, "node_id": node_id,
@ -305,7 +312,7 @@ class PromptExecutor:
"current_inputs": error["current_inputs"], "current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"], "current_outputs": error["current_outputs"],
} }
self.server.send_sync("execution_error", mes, self.server.client_id) self.add_note("execution_error", mes, broadcast=False)
# Next, remove the subsequent outputs since they will not be executed # Next, remove the subsequent outputs since they will not be executed
to_delete = [] to_delete = []
@ -327,8 +334,7 @@ class PromptExecutor:
else: else:
self.server.client_id = None self.server.client_id = None
if self.server.client_id is not None: self.add_note("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
with torch.inference_mode(): with torch.inference_mode():
#delete cached outputs if nodes don't exist for them #delete cached outputs if nodes don't exist for them
@ -361,8 +367,9 @@ class PromptExecutor:
del d del d
comfy.model_management.cleanup_models() comfy.model_management.cleanup_models()
if self.server.client_id is not None: self.add_note("execution_cached",
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) { "nodes": list(current_outputs) , "prompt_id": prompt_id},
broadcast=False)
executed = set() executed = set()
output_node_id = None output_node_id = None
to_execute = [] to_execute = []
@ -378,8 +385,8 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in # This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the # the actual SD code, instead it will report the node where the
# error was raised # error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if success is not True: if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break break
@ -731,14 +738,27 @@ class PromptQueue:
self.server.queue_updated() self.server.queue_updated()
return (item, i) return (item, i)
def task_done(self, item_id, outputs): class ExecutionStatus(NamedTuple):
status_str: Literal['success', 'error']
completed: bool
notes: List[str]
def task_done(self, item_id, outputs,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex: with self.mutex:
prompt = self.currently_running.pop(item_id) prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE: if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history))) self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs: status_dict: dict|None = None
self.history[prompt[1]]["outputs"][o] = outputs[o] if status is not None:
status_dict = copy.deepcopy(status._asdict())
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
'status': status_dict,
}
self.server.queue_updated() self.server.queue_updated()
def get_current_queue(self): def get_current_queue(self):

View File

@ -110,7 +110,12 @@ def prompt_worker(q, server):
e.execute(item[2], prompt_id, item[3], item[4]) e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
q.task_done(item_id, e.outputs_ui) q.task_done(item_id,
e.outputs_ui,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
notes=e.status_notes))
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)