Add a way to interrupt current processing in the backend.

comfyanonymous 2023-03-02 14:42:03 -05:00
parent 1e2c4df972
commit 69cc75fbf8
5 changed files with 44 additions and 1 deletion

comfy/model_management.py

@@ -162,3 +162,31 @@ def maximum_batch_area():
     memory_free = get_free_memory() / (1024 * 1024)
     area = ((memory_free - 1024) * 0.9) / (0.6)
     return int(max(area, 0))
+#TODO: might be cleaner to put this somewhere else
+import threading
+
+class InterruptProcessingException(Exception):
+    pass
+
+interrupt_processing_mutex = threading.RLock()
+
+interrupt_processing = False
+def interrupt_current_processing(value=True):
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        interrupt_processing = value
+
+def processing_interrupted():
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        return interrupt_processing
+
+def throw_exception_if_processing_interrupted():
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        if interrupt_processing:
+            interrupt_processing = False
+            raise InterruptProcessingException()
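
The flag these helpers implement is cooperative: nothing is killed from the outside, so long-running code has to call throw_exception_if_processing_interrupted() at points where stopping is safe, and the check-and-reset happens under the lock so one interrupt request fires exactly once. A minimal usage sketch, assuming the module can be imported as comfy.model_management; the worker loop and do_one_step are invented for illustration:

import threading
import comfy.model_management as model_management

def do_one_step(step):
    pass  # stand-in for one expensive unit of work

def worker():
    try:
        for step in range(1000):
            do_one_step(step)
            # Cancellation point: raises InterruptProcessingException
            # (and clears the flag) if an interrupt was requested.
            model_management.throw_exception_if_processing_interrupted()
    except model_management.InterruptProcessingException:
        print("processing interrupted; flag reset for the next run")

t = threading.Thread(target=worker)
t.start()
model_management.interrupt_current_processing()  # e.g. from a server thread
t.join()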

comfy/samplers.py

@@ -172,6 +172,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
                 del input_x
 
+                model_management.throw_exception_if_processing_interrupted()
+
                 for o in range(batch_chunks):
                     if cond_or_uncond[o] == COND:
                         out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]

execution.py

@@ -58,6 +58,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
         server.send_sync("executing", { "node": unique_id }, server.client_id)
     obj = class_def()
 
+    nodes.before_node_execution()
     outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
     if "ui" in outputs[unique_id] and server.client_id is not None:
         server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)

nodes.py

@@ -41,6 +41,13 @@ def recursive_search(directory):
 def filter_files_extensions(files, extensions):
     return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
 
+def before_node_execution():
+    model_management.throw_exception_if_processing_interrupted()
+
+def interrupt_processing():
+    model_management.interrupt_current_processing()
+
+
 class CLIPTextEncode:
     @classmethod
     def INPUT_TYPES(s):
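
These two wrappers give node code an entry point to the interrupt machinery without importing model_management itself. A node whose FUNCTION loops for a long time can call before_node_execution() inside the loop so it can be stopped mid-node rather than only at node boundaries; a hypothetical custom node, invented for illustration:

import nodes

class SlowIterate:
    # Invented example node: does many small steps, staying interruptible.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"steps": ("INT", {"default": 100, "min": 1})}}

    RETURN_TYPES = ("INT",)
    FUNCTION = "run"
    CATEGORY = "example"

    def run(self, steps):
        total = 0
        for i in range(steps):
            total += i  # stand-in for real per-step work
            # Aborts this node between iterations if an interrupt is pending.
            nodes.before_node_execution()
        return (total,)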

server.py

@@ -140,7 +140,12 @@ class PromptServer():
                     self.prompt_queue.delete_queue_item(delete_func)
 
             return web.Response(status=200)
 
+        @routes.post("/interrupt")
+        async def post_interrupt(request):
+            nodes.interrupt_processing()
+            return web.Response(status=200)
+
         @routes.post("/history")
         async def post_history(request):
             json_data = await request.json()
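
With the route in place, interruption is reachable over HTTP: a POST to /interrupt raises the flag for whatever is currently running and returns 200. A minimal client sketch, assuming the server listens on the default http://127.0.0.1:8188:

import urllib.request

# Empty-body POST; the handler ignores the payload.
req = urllib.request.Request(
    "http://127.0.0.1:8188/interrupt", data=b"", method="POST")
with urllib.request.urlopen(req) as resp:
    print(resp.status)  # 200 once the interrupt flag is set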