Add more features to the backend queue code.
The queue can now be queried, entries can be deleted, and prompts can easily be queued to the front of the queue. This still needs to be exposed in the UI next.
This commit is contained in:
parent
9d611a90e8
commit
4b08314257
91
main.py
91
main.py
|
@ -3,7 +3,7 @@ import sys
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
import queue
|
import heapq
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
if '--dont-upcast-attention' in sys.argv:
|
if '--dont-upcast-attention' in sys.argv:
|
||||||
|
@ -148,6 +148,7 @@ class PromptExecutor:
|
||||||
to_execute += [(0, x)]
|
to_execute += [(0, x)]
|
||||||
|
|
||||||
while len(to_execute) > 0:
|
while len(to_execute) > 0:
|
||||||
|
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||||
x = to_execute.pop(0)[-1]
|
x = to_execute.pop(0)[-1]
|
||||||
|
|
||||||
|
@ -266,10 +267,63 @@ def validate_prompt(prompt):
|
||||||
def prompt_worker(q):
    """Consume prompts from *q* forever, executing each one in turn.

    Each queue entry is a tuple whose last two elements are the prompt and
    its extra data; ``q.get()`` also yields the task id used to mark the
    entry finished via ``q.task_done``.
    """
    executor = PromptExecutor()
    while True:
        queue_item, task_id = q.get()
        # Entry layout: (..., prompt, extra_data) — execute takes the last two.
        executor.execute(queue_item[-2], queue_item[-1])
        q.task_done(task_id)
|
||||||
|
|
||||||
|
class PromptQueue:
    """Thread-safe priority queue of prompts that also tracks running tasks.

    Pending items live in a heap (smallest tuple first); items handed out by
    ``get`` are remembered in ``currently_running`` under an incrementing
    task id until ``task_done`` is called for that id.
    """

    def __init__(self):
        # One re-entrant lock guards both the heap and the running map; the
        # condition shares it so get() can block until work is available.
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        self.task_counter = 0
        self.queue = []               # heap of pending items
        self.currently_running = {}   # task id -> deep copy of the item

    def put(self, item):
        """Add *item* to the pending heap and wake one waiting consumer."""
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.not_empty.notify()

    def get(self):
        """Block until an item is pending, then return ``(item, task_id)``.

        The item is recorded (as a deep copy) in ``currently_running`` under
        the returned task id until ``task_done(task_id)`` is called.
        """
        with self.not_empty:
            while not self.queue:
                self.not_empty.wait()
            item = heapq.heappop(self.queue)
            task_id = self.task_counter
            self.currently_running[task_id] = copy.deepcopy(item)
            self.task_counter += 1
            return (item, task_id)

    def task_done(self, item_id):
        """Mark the running task *item_id* as finished and forget it."""
        with self.mutex:
            self.currently_running.pop(item_id)

    def get_current_queue(self):
        """Return ``(running_items, pending_items)``.

        Running items are the stored copies; the pending heap is deep-copied
        so callers cannot mutate the live queue.
        """
        with self.mutex:
            running = list(self.currently_running.values())
            return (running, copy.deepcopy(self.queue))

    def get_tasks_remaining(self):
        """Count of pending plus currently-running tasks."""
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        """Discard every pending item (running tasks are untouched)."""
        with self.mutex:
            self.queue = []

    def delete_queue_item(self, function):
        """Remove the first pending item for which *function(item)* is true.

        Returns True when an item was removed, False otherwise. The heap
        invariant is restored after removal.
        """
        with self.mutex:
            for index, entry in enumerate(self.queue):
                if not function(entry):
                    continue
                if len(self.queue) == 1:
                    self.wipe_queue()
                else:
                    self.queue.pop(index)
                    heapq.heapify(self.queue)
                return True
            return False
|
||||||
|
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
|
|
||||||
|
@ -285,9 +339,16 @@ class PromptServer(BaseHTTPRequestHandler):
|
||||||
self._set_headers(ct='application/json')
|
self._set_headers(ct='application/json')
|
||||||
prompt_info = {}
|
prompt_info = {}
|
||||||
exec_info = {}
|
exec_info = {}
|
||||||
exec_info['queue_remaining'] = self.server.prompt_queue.unfinished_tasks
|
exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining()
|
||||||
prompt_info['exec_info'] = exec_info
|
prompt_info['exec_info'] = exec_info
|
||||||
self.wfile.write(json.dumps(prompt_info).encode('utf-8'))
|
self.wfile.write(json.dumps(prompt_info).encode('utf-8'))
|
||||||
|
elif self.path == "/queue":
|
||||||
|
self._set_headers(ct='application/json')
|
||||||
|
queue_info = {}
|
||||||
|
current_queue = self.server.prompt_queue.get_current_queue()
|
||||||
|
queue_info['queue_running'] = current_queue[0]
|
||||||
|
queue_info['queue_pending'] = current_queue[1]
|
||||||
|
self.wfile.write(json.dumps(queue_info).encode('utf-8'))
|
||||||
elif self.path == "/object_info":
|
elif self.path == "/object_info":
|
||||||
self._set_headers(ct='application/json')
|
self._set_headers(ct='application/json')
|
||||||
out = {}
|
out = {}
|
||||||
|
@ -325,12 +386,16 @@ class PromptServer(BaseHTTPRequestHandler):
|
||||||
out_string = ""
|
out_string = ""
|
||||||
if self.path == "/prompt":
|
if self.path == "/prompt":
|
||||||
print("got prompt")
|
print("got prompt")
|
||||||
self.data_string = self.rfile.read(int(self.headers['Content-Length']))
|
data_string = self.rfile.read(int(self.headers['Content-Length']))
|
||||||
json_data = json.loads(self.data_string)
|
json_data = json.loads(data_string)
|
||||||
if "number" in json_data:
|
if "number" in json_data:
|
||||||
number = float(json_data['number'])
|
number = float(json_data['number'])
|
||||||
else:
|
else:
|
||||||
number = self.server.number
|
number = self.server.number
|
||||||
|
if "front" in json_data:
|
||||||
|
if json_data['front']:
|
||||||
|
number = -number
|
||||||
|
|
||||||
self.server.number += 1
|
self.server.number += 1
|
||||||
if "prompt" in json_data:
|
if "prompt" in json_data:
|
||||||
prompt = json_data["prompt"]
|
prompt = json_data["prompt"]
|
||||||
|
@ -344,6 +409,18 @@ class PromptServer(BaseHTTPRequestHandler):
|
||||||
resp_code = 400
|
resp_code = 400
|
||||||
out_string = valid[1]
|
out_string = valid[1]
|
||||||
print("invalid prompt:", valid[1])
|
print("invalid prompt:", valid[1])
|
||||||
|
elif self.path == "/queue":
|
||||||
|
data_string = self.rfile.read(int(self.headers['Content-Length']))
|
||||||
|
json_data = json.loads(data_string)
|
||||||
|
if "clear" in json_data:
|
||||||
|
if json_data["clear"]:
|
||||||
|
self.server.prompt_queue.wipe_queue()
|
||||||
|
if "delete" in json_data:
|
||||||
|
to_delete = json_data['delete']
|
||||||
|
for id_to_delete in to_delete:
|
||||||
|
delete_func = lambda a: a[1] == int(id_to_delete)
|
||||||
|
self.server.prompt_queue.delete_queue_item(delete_func)
|
||||||
|
|
||||||
self._set_headers(code=resp_code)
|
self._set_headers(code=resp_code)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(out_string.encode('utf8'))
|
self.wfile.write(out_string.encode('utf8'))
|
||||||
|
@ -366,7 +443,7 @@ def run(prompt_queue, address='', port=8188):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # One background worker consumes prompts; the HTTP server produces them.
    prompt_queue = PromptQueue()
    worker = threading.Thread(target=prompt_worker, daemon=True, args=(prompt_queue,))
    worker.start()
    run(prompt_queue, address='127.0.0.1', port=8188)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue