Add a way to interrupt current processing in the backend.
This commit is contained in:
parent
1e2c4df972
commit
69cc75fbf8
|
@ -162,3 +162,31 @@ def maximum_batch_area():
|
|||
memory_free = get_free_memory() / (1024 * 1024)
|
||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||
return int(max(area, 0))
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
class InterruptProcessingException(Exception):
|
||||
pass
|
||||
|
||||
interrupt_processing_mutex = threading.RLock()
|
||||
|
||||
interrupt_processing = False
|
||||
def interrupt_current_processing(value=True):
|
||||
global interrupt_processing
|
||||
global interrupt_processing_mutex
|
||||
with interrupt_processing_mutex:
|
||||
interrupt_processing = value
|
||||
|
||||
def processing_interrupted():
|
||||
global interrupt_processing
|
||||
global interrupt_processing_mutex
|
||||
with interrupt_processing_mutex:
|
||||
return interrupt_processing
|
||||
|
||||
def throw_exception_if_processing_interrupted():
|
||||
global interrupt_processing
|
||||
global interrupt_processing_mutex
|
||||
with interrupt_processing_mutex:
|
||||
if interrupt_processing:
|
||||
interrupt_processing = False
|
||||
raise InterruptProcessingException()
|
||||
|
|
|
@ -172,6 +172,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
for o in range(batch_chunks):
|
||||
if cond_or_uncond[o] == COND:
|
||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
|
|
|
@ -58,6 +58,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
|
|||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
nodes.before_node_execution()
|
||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
||||
if "ui" in outputs[unique_id] and server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
||||
|
|
7
nodes.py
7
nodes.py
|
@ -41,6 +41,13 @@ def recursive_search(directory):
|
|||
def filter_files_extensions(files, extensions):
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||
|
||||
|
||||
def before_node_execution():
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
def interrupt_processing():
|
||||
model_management.interrupt_current_processing()
|
||||
|
||||
class CLIPTextEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
|
|
@ -140,7 +140,12 @@ class PromptServer():
|
|||
self.prompt_queue.delete_queue_item(delete_func)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@routes.post("/interrupt")
|
||||
async def post_interrupt(request):
|
||||
nodes.interrupt_processing()
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/history")
|
||||
async def post_history(request):
|
||||
json_data = await request.json()
|
||||
|
|
Loading…
Reference in New Issue