Add more features to the backend queue code.

The queue can now be queried, entries can be deleted and prompts easily
queued to the front of the queue.

Just need to expose it in the UI next.
This commit is contained in:
comfyanonymous 2023-02-01 22:33:10 -05:00
parent 9d611a90e8
commit 4b08314257
1 changed files with 84 additions and 7 deletions

91
main.py
View File

@ -3,7 +3,7 @@ import sys
import copy import copy
import json import json
import threading import threading
import queue import heapq
import traceback import traceback
if '--dont-upcast-attention' in sys.argv: if '--dont-upcast-attention' in sys.argv:
@ -148,6 +148,7 @@ class PromptExecutor:
to_execute += [(0, x)] to_execute += [(0, x)]
while len(to_execute) > 0: while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1] x = to_execute.pop(0)[-1]
@ -266,10 +267,63 @@ def validate_prompt(prompt):
def prompt_worker(q): def prompt_worker(q):
e = PromptExecutor() e = PromptExecutor()
while True: while True:
item = q.get() item, item_id = q.get()
e.execute(item[-2], item[-1]) e.execute(item[-2], item[-1])
q.task_done() q.task_done(item_id)
class PromptQueue:
def __init__(self):
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0
self.queue = []
self.currently_running = {}
def put(self, item):
with self.mutex:
heapq.heappush(self.queue, item)
self.not_empty.notify()
def get(self):
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait()
item = heapq.heappop(self.queue)
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1
return (item, i)
def task_done(self, item_id):
with self.mutex:
self.currently_running.pop(item_id)
def get_current_queue(self):
with self.mutex:
out = []
for x in self.currently_running.values():
out += [x]
return (out, copy.deepcopy(self.queue))
def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)
def wipe_queue(self):
with self.mutex:
self.queue = []
def delete_queue_item(self, function):
with self.mutex:
for x in range(len(self.queue)):
if function(self.queue[x]):
if len(self.queue) == 1:
self.wipe_queue()
else:
self.queue.pop(x)
heapq.heapify(self.queue)
return True
return False
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
@ -285,9 +339,16 @@ class PromptServer(BaseHTTPRequestHandler):
self._set_headers(ct='application/json') self._set_headers(ct='application/json')
prompt_info = {} prompt_info = {}
exec_info = {} exec_info = {}
exec_info['queue_remaining'] = self.server.prompt_queue.unfinished_tasks exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info prompt_info['exec_info'] = exec_info
self.wfile.write(json.dumps(prompt_info).encode('utf-8')) self.wfile.write(json.dumps(prompt_info).encode('utf-8'))
elif self.path == "/queue":
self._set_headers(ct='application/json')
queue_info = {}
current_queue = self.server.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
self.wfile.write(json.dumps(queue_info).encode('utf-8'))
elif self.path == "/object_info": elif self.path == "/object_info":
self._set_headers(ct='application/json') self._set_headers(ct='application/json')
out = {} out = {}
@ -325,12 +386,16 @@ class PromptServer(BaseHTTPRequestHandler):
out_string = "" out_string = ""
if self.path == "/prompt": if self.path == "/prompt":
print("got prompt") print("got prompt")
self.data_string = self.rfile.read(int(self.headers['Content-Length'])) data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(self.data_string) json_data = json.loads(data_string)
if "number" in json_data: if "number" in json_data:
number = float(json_data['number']) number = float(json_data['number'])
else: else:
number = self.server.number number = self.server.number
if "front" in json_data:
if json_data['front']:
number = -number
self.server.number += 1 self.server.number += 1
if "prompt" in json_data: if "prompt" in json_data:
prompt = json_data["prompt"] prompt = json_data["prompt"]
@ -344,6 +409,18 @@ class PromptServer(BaseHTTPRequestHandler):
resp_code = 400 resp_code = 400
out_string = valid[1] out_string = valid[1]
print("invalid prompt:", valid[1]) print("invalid prompt:", valid[1])
elif self.path == "/queue":
data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(data_string)
if "clear" in json_data:
if json_data["clear"]:
self.server.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.server.prompt_queue.delete_queue_item(delete_func)
self._set_headers(code=resp_code) self._set_headers(code=resp_code)
self.end_headers() self.end_headers()
self.wfile.write(out_string.encode('utf8')) self.wfile.write(out_string.encode('utf8'))
@ -366,7 +443,7 @@ def run(prompt_queue, address='', port=8188):
if __name__ == "__main__": if __name__ == "__main__":
q = queue.PriorityQueue() q = PromptQueue()
threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start() threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start()
run(q, address='127.0.0.1', port=8188) run(q, address='127.0.0.1', port=8188)