diff --git a/execution.py b/execution.py index bca48a78..7db1f095 100644 --- a/execution.py +++ b/execution.py @@ -700,10 +700,12 @@ class PromptQueue: self.server.queue_updated() self.not_empty.notify() - def get(self): + def get(self, timeout=None): with self.not_empty: while len(self.queue) == 0: - self.not_empty.wait() + self.not_empty.wait(timeout=timeout) + if timeout is not None and len(self.queue) == 0: + return None item = heapq.heappop(self.queue) i = self.task_counter self.currently_running[i] = copy.deepcopy(item) diff --git a/main.py b/main.py index 3997fbef..1f9c5f44 100644 --- a/main.py +++ b/main.py @@ -89,23 +89,36 @@ def cuda_malloc_warning(): def prompt_worker(q, server): e = execution.PromptExecutor(server) last_gc_collect = 0 - while True: - item, item_id = q.get() - execution_start_time = time.perf_counter() - prompt_id = item[1] - e.execute(item[2], prompt_id, item[3], item[4]) - q.task_done(item_id, e.outputs_ui) - if server.client_id is not None: - server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) + need_gc = False + gc_collect_interval = 10.0 - current_time = time.perf_counter() - execution_time = current_time - execution_start_time - print("Prompt executed in {:.2f} seconds".format(execution_time)) - if (current_time - last_gc_collect) > 10.0: - gc.collect() - comfy.model_management.soft_empty_cache() - last_gc_collect = current_time - print("gc collect") + while True: + timeout = None + if need_gc: + timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) + + queue_item = q.get(timeout=timeout) + if queue_item is not None: + item, item_id = queue_item + execution_start_time = time.perf_counter() + prompt_id = item[1] + e.execute(item[2], prompt_id, item[3], item[4]) + need_gc = True + q.task_done(item_id, e.outputs_ui) + if server.client_id is not None: + server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) + + current_time = time.perf_counter() + execution_time = current_time - execution_start_time + print("Prompt executed in {:.2f} seconds".format(execution_time)) + + if need_gc: + current_time = time.perf_counter() + if (current_time - last_gc_collect) > gc_collect_interval: + gc.collect() + comfy.model_management.soft_empty_cache() + last_gc_collect = current_time + need_gc = False async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())