diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2ec6bbea..36f925c4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -162,3 +162,31 @@ def maximum_batch_area():
     memory_free = get_free_memory() / (1024 * 1024)
     area = ((memory_free - 1024) * 0.9) / (0.6)
     return int(max(area, 0))
+#TODO: might be cleaner to put this somewhere else
+import threading
+
+class InterruptProcessingException(Exception):
+    pass
+
+interrupt_processing_mutex = threading.RLock()
+
+interrupt_processing = False
+def interrupt_current_processing(value=True):
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        interrupt_processing = value
+
+def processing_interrupted():
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        return interrupt_processing
+
+def throw_exception_if_processing_interrupted():
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        if interrupt_processing:
+            interrupt_processing = False
+            raise InterruptProcessingException()
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 14b927b6..3562f89d 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -172,6 +172,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
             del input_x
 
+            model_management.throw_exception_if_processing_interrupted()
+
             for o in range(batch_chunks):
                 if cond_or_uncond[o] == COND:
                     out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
diff --git a/execution.py b/execution.py
index 128c5638..b0f4f952 100644
--- a/execution.py
+++ b/execution.py
@@ -58,6 +58,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
         server.send_sync("executing", { "node": unique_id }, server.client_id)
     obj = class_def()
 
+    nodes.before_node_execution()
     outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
     if "ui" in outputs[unique_id] and server.client_id is not None:
         server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
diff --git a/nodes.py b/nodes.py
index 1b22e18d..fe24a6cd 100644
--- a/nodes.py
+++ b/nodes.py
@@ -41,6 +41,13 @@ def recursive_search(directory):
 
 def filter_files_extensions(files, extensions):
     return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
+
+def before_node_execution():
+    model_management.throw_exception_if_processing_interrupted()
+
+def interrupt_processing():
+    model_management.interrupt_current_processing()
+
 class CLIPTextEncode:
     @classmethod
     def INPUT_TYPES(s):
diff --git a/server.py b/server.py
index 7af5302b..307352fe 100644
--- a/server.py
+++ b/server.py
@@ -140,7 +140,12 @@ class PromptServer():
                     self.prompt_queue.delete_queue_item(delete_func)
 
             return web.Response(status=200)
-
+
+        @routes.post("/interrupt")
+        async def post_interrupt(request):
+            nodes.interrupt_processing()
+            return web.Response(status=200)
+
         @routes.post("/history")
         async def post_history(request):
            json_data = await request.json()