Fix some potential issues related to threads.

This commit is contained in:
comfyanonymous 2023-02-25 18:36:29 -05:00
parent 8074a58a1a
commit 6de6246dd4
2 changed files with 26 additions and 14 deletions

12
main.py
View File

@ -371,6 +371,18 @@ class PromptQueue:
return True return True
return False return False
def get_history(self):
with self.mutex:
return copy.deepcopy(self.history)
def wipe_history(self):
with self.mutex:
self.history = {}
def delete_history_item(self, id_to_delete):
with self.mutex:
self.history.pop(id_to_delete, None)
async def run(server, address='', port=8188): async def run(server, address='', port=8188):
await asyncio.gather(server.start(address, port), server.publish_loop()) await asyncio.gather(server.start(address, port), server.publish_loop())

View File

@ -47,7 +47,7 @@ class PromptServer():
@routes.get("/") @routes.get("/")
async def get_root(request): async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html")) return web.FileResponse(os.path.join(self.web_root, "index.html"))
@routes.get("/view/{file}") @routes.get("/view/{file}")
async def view_image(request): async def view_image(request):
if "file" in request.match_info: if "file" in request.match_info:
@ -59,11 +59,11 @@ class PromptServer():
return web.FileResponse(file) return web.FileResponse(file)
return web.Response(status=404) return web.Response(status=404)
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())
@routes.get("/object_info") @routes.get("/object_info")
async def get_object_info(request): async def get_object_info(request):
out = {} out = {}
@ -79,11 +79,11 @@ class PromptServer():
info['category'] = obj_class.CATEGORY info['category'] = obj_class.CATEGORY
out[x] = info out[x] = info
return web.json_response(out) return web.json_response(out)
@routes.get("/history") @routes.get("/history")
async def get_history(request): async def get_history(request):
return web.json_response(self.prompt_queue.history) return web.json_response(self.prompt_queue.get_history())
@routes.get("/queue") @routes.get("/queue")
async def get_queue(request): async def get_queue(request):
queue_info = {} queue_info = {}
@ -91,7 +91,7 @@ class PromptServer():
queue_info['queue_running'] = current_queue[0] queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1] queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info) return web.json_response(queue_info)
@routes.post("/prompt") @routes.post("/prompt")
async def post_prompt(request): async def post_prompt(request):
print("got prompt") print("got prompt")
@ -146,14 +146,14 @@ class PromptServer():
json_data = await request.json() json_data = await request.json()
if "clear" in json_data: if "clear" in json_data:
if json_data["clear"]: if json_data["clear"]:
self.prompt_queue.history = {} self.prompt_queue.wipe_history()
if "delete" in json_data: if "delete" in json_data:
to_delete = json_data['delete'] to_delete = json_data['delete']
for id_to_delete in to_delete: for id_to_delete in to_delete:
self.prompt_queue.history.pop(id_to_delete, None) self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200) return web.Response(status=200)
self.app.add_routes(routes) self.app.add_routes(routes)
self.app.add_routes([ self.app.add_routes([
web.static('/', self.web_root), web.static('/', self.web_root),
@ -181,7 +181,7 @@ class PromptServer():
def send_sync(self, event, data, sid=None): def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe( self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid)) self.messages.put_nowait, (event, data, sid))
def queue_updated(self): def queue_updated(self):
self.send_sync("status", { "status": self.get_queue_info() }) self.send_sync("status", { "status": self.get_queue_info() })
@ -195,8 +195,8 @@ class PromptServer():
await runner.setup() await runner.setup()
site = web.TCPSite(runner, address, port) site = web.TCPSite(runner, address, port)
await site.start() await site.start()
if address == '': if address == '':
address = '0.0.0.0' address = '0.0.0.0'
print("Starting server\n") print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port)) print("To see the GUI go to: http://{}:{}".format(address, port))