From 6de6246dd47840df5a19cfe9590deb1c31011290 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 25 Feb 2023 18:36:29 -0500 Subject: [PATCH] Fix some potential issues related to threads. --- main.py | 12 ++++++++++++ server.py | 28 ++++++++++++++-------------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/main.py b/main.py index be5bb1ce..68d4f29f 100644 --- a/main.py +++ b/main.py @@ -371,6 +371,18 @@ class PromptQueue: return True return False + def get_history(self): + with self.mutex: + return copy.deepcopy(self.history) + + def wipe_history(self): + with self.mutex: + self.history = {} + + def delete_history_item(self, id_to_delete): + with self.mutex: + self.history.pop(id_to_delete, None) + async def run(server, address='', port=8188): await asyncio.gather(server.start(address, port), server.publish_loop()) diff --git a/server.py b/server.py index 2beb6cf8..cc7d4a9c 100644 --- a/server.py +++ b/server.py @@ -47,7 +47,7 @@ class PromptServer(): @routes.get("/") async def get_root(request): return web.FileResponse(os.path.join(self.web_root, "index.html")) - + @routes.get("/view/{file}") async def view_image(request): if "file" in request.match_info: @@ -59,11 +59,11 @@ class PromptServer(): return web.FileResponse(file) return web.Response(status=404) - + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) - + @routes.get("/object_info") async def get_object_info(request): out = {} @@ -79,11 +79,11 @@ class PromptServer(): info['category'] = obj_class.CATEGORY out[x] = info return web.json_response(out) - + @routes.get("/history") async def get_history(request): - return web.json_response(self.prompt_queue.history) - + return web.json_response(self.prompt_queue.get_history()) + @routes.get("/queue") async def get_queue(request): queue_info = {} @@ -91,7 +91,7 @@ class PromptServer(): queue_info['queue_running'] = current_queue[0] queue_info['queue_pending'] = current_queue[1] return web.json_response(queue_info) - + @routes.post("/prompt") async def post_prompt(request): print("got prompt") @@ -146,14 +146,14 @@ class PromptServer(): json_data = await request.json() if "clear" in json_data: if json_data["clear"]: - self.prompt_queue.history = {} + self.prompt_queue.wipe_history() if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: - self.prompt_queue.history.pop(id_to_delete, None) - + self.prompt_queue.delete_history_item(id_to_delete) + return web.Response(status=200) - + self.app.add_routes(routes) self.app.add_routes([ web.static('/', self.web_root), @@ -181,7 +181,7 @@ class PromptServer(): def send_sync(self, event, data, sid=None): self.loop.call_soon_threadsafe( self.messages.put_nowait, (event, data, sid)) - + def queue_updated(self): self.send_sync("status", { "status": self.get_queue_info() }) @@ -195,8 +195,8 @@ class PromptServer(): await runner.setup() site = web.TCPSite(runner, address, port) await site.start() - + if address == '': address = '0.0.0.0' print("Starting server\n") - print("To see the GUI go to: http://{}:{}".format(address, port)) \ No newline at end of file + print("To see the GUI go to: http://{}:{}".format(address, port))