Add the prompt_id to some websocket messages.

This commit is contained in:
comfyanonymous 2023-05-11 01:22:40 -04:00
parent 974958ff81
commit dfc74c19d9
2 changed files with 5 additions and 5 deletions

View File

@ -147,7 +147,7 @@ class PromptExecutor:
self.old_prompt = {} self.old_prompt = {}
self.server = server self.server = server
def execute(self, prompt, extra_data={}, execute_outputs=[]): def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False) nodes.interrupt_processing(False)
if "client_id" in extra_data: if "client_id" in extra_data:
@ -170,7 +170,7 @@ class PromptExecutor:
current_outputs = set(self.outputs.keys()) current_outputs = set(self.outputs.keys())
if self.server.client_id is not None: if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
executed = set() executed = set()
try: try:
to_execute = [] to_execute = []
@ -190,7 +190,7 @@ class PromptExecutor:
message = str(traceback.format_exc()) message = str(traceback.format_exc())
print(message) print(message)
if self.server.client_id is not None: if self.server.client_id is not None:
self.server.send_sync("execution_error", { "message": message }, self.server.client_id) self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
to_delete = [] to_delete = []
for o in self.outputs: for o in self.outputs:
@ -207,7 +207,7 @@ class PromptExecutor:
self.old_prompt[x] = copy.deepcopy(prompt[x]) self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None self.server.last_node_id = None
if self.server.client_id is not None: if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None }, self.server.client_id) self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
gc.collect() gc.collect()
comfy.model_management.soft_empty_cache() comfy.model_management.soft_empty_cache()

View File

@ -33,7 +33,7 @@ def prompt_worker(q, server):
e = execution.PromptExecutor(server) e = execution.PromptExecutor(server)
while True: while True:
item, item_id = q.get() item, item_id = q.get()
e.execute(item[-3], item[-2], item[-1]) e.execute(item[2], item[1], item[3], item[4])
q.task_done(item_id, e.outputs) q.task_done(item_id, e.outputs)
async def run(server, address='', port=8188, verbose=True, call_on_start=None): async def run(server, address='', port=8188, verbose=True, call_on_start=None):