Add a /free route to unload models or free all memory.
A POST request to /free with {"unload_models": true} unloads the models from VRAM. A POST request to /free with {"free_memory": true} unloads the models and frees all cached data from the last run workflow.
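For illustration, a minimal client-side sketch of calling the new route. The helper name and the host/port (ComfyUI's usual default of 127.0.0.1:8188) are assumptions for the example; only the JSON payload keys come from this commit. The route itself just queues flags; the actual unload happens in the prompt worker loop.

# Hypothetical helper for exercising the /free route; host/port and the
# function name are assumptions, the payload keys are from the commit.
import json
import urllib.request

def post_free(unload_models=False, free_memory=False, host="127.0.0.1:8188"):
    payload = json.dumps({"unload_models": unload_models,
                          "free_memory": free_memory}).encode("utf-8")
    req = urllib.request.Request("http://{}/free".format(host), data=payload,
                                 headers={"Content-Type": "application/json"})
    # The server answers with an empty 200 once the flags are set on the queue.
    with urllib.request.urlopen(req) as resp:
        return resp.status

post_free(unload_models=True)   # unload models from VRAM
post_free(free_memory=True)     # unload models and free cached workflow data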
parent 8c6493578b
commit 6d281b4ff4
execution.py (20 lines changed)
@@ -268,11 +268,14 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item

 class PromptExecutor:
     def __init__(self, server):
+        self.server = server
+        self.reset()
+
+    def reset(self):
         self.outputs = {}
         self.object_storage = {}
         self.outputs_ui = {}
         self.old_prompt = {}
-        self.server = server

     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
         node_id = error["node_id"]
@@ -706,6 +709,7 @@ class PromptQueue:
         self.queue = []
         self.currently_running = {}
         self.history = {}
+        self.flags = {}
         server.prompt_queue = self

     def put(self, item):
@@ -792,3 +796,17 @@ class PromptQueue:
     def delete_history_item(self, id_to_delete):
         with self.mutex:
             self.history.pop(id_to_delete, None)
+
+    def set_flag(self, name, data):
+        with self.mutex:
+            self.flags[name] = data
+            self.not_empty.notify()
+
+    def get_flags(self, reset=True):
+        with self.mutex:
+            if reset:
+                ret = self.flags
+                self.flags = {}
+                return ret
+            else:
+                return self.flags.copy()
main.py (15 lines changed)
@@ -97,7 +97,7 @@ def prompt_worker(q, server):
     gc_collect_interval = 10.0

     while True:
-        timeout = None
+        timeout = 1000.0
         if need_gc:
             timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)

@@ -118,6 +118,19 @@ def prompt_worker(q, server):
             execution_time = current_time - execution_start_time
             print("Prompt executed in {:.2f} seconds".format(execution_time))

+        flags = q.get_flags()
+        free_memory = flags.get("free_memory", False)
+
+        if flags.get("unload_models", free_memory):
+            comfy.model_management.unload_all_models()
+            need_gc = True
+            last_gc_collect = 0
+
+        if free_memory:
+            e.reset()
+            need_gc = True
+            last_gc_collect = 0
+
         if need_gc:
            current_time = time.perf_counter()
            if (current_time - last_gc_collect) > gc_collect_interval:
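One detail worth noting in the worker loop above: flags.get("unload_models", free_memory) uses free_memory as the fallback, so a free_memory request also triggers a model unload even when unload_models was never set. A small standalone illustration of that lookup (values are made up for demonstration, not taken from the real queue):

# Illustration of the flag logic in prompt_worker: free_memory implies a
# model unload because it is the default for the "unload_models" lookup.
def describe(flags):
    free_memory = flags.get("free_memory", False)
    unload_models = flags.get("unload_models", free_memory)
    return {"unload_models": unload_models, "free_memory": free_memory}

print(describe({"unload_models": True}))  # unload only
print(describe({"free_memory": True}))    # unload models and reset caches
print(describe({}))                       # nothing to do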
server.py (11 lines changed)
@@ -507,6 +507,17 @@ class PromptServer():
             nodes.interrupt_processing()
             return web.Response(status=200)

+        @routes.post("/free")
+        async def post_interrupt(request):
+            json_data = await request.json()
+            unload_models = json_data.get("unload_models", False)
+            free_memory = json_data.get("free_memory", False)
+            if unload_models:
+                self.prompt_queue.set_flag("unload_models", unload_models)
+            if free_memory:
+                self.prompt_queue.set_flag("free_memory", free_memory)
+            return web.Response(status=200)
+
         @routes.post("/history")
         async def post_history(request):
             json_data = await request.json()