Add a way to interrupt current processing in the backend.
This commit is contained in:
parent
1e2c4df972
commit
69cc75fbf8
|
@ -162,3 +162,31 @@ def maximum_batch_area():
|
||||||
memory_free = get_free_memory() / (1024 * 1024)
|
memory_free = get_free_memory() / (1024 * 1024)
|
||||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||||
return int(max(area, 0))
|
return int(max(area, 0))
|
||||||
|
#TODO: might be cleaner to put this somewhere else
|
||||||
|
import threading
|
||||||
|
|
||||||
|
class InterruptProcessingException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
interrupt_processing_mutex = threading.RLock()
|
||||||
|
|
||||||
|
interrupt_processing = False
|
||||||
|
def interrupt_current_processing(value=True):
|
||||||
|
global interrupt_processing
|
||||||
|
global interrupt_processing_mutex
|
||||||
|
with interrupt_processing_mutex:
|
||||||
|
interrupt_processing = value
|
||||||
|
|
||||||
|
def processing_interrupted():
|
||||||
|
global interrupt_processing
|
||||||
|
global interrupt_processing_mutex
|
||||||
|
with interrupt_processing_mutex:
|
||||||
|
return interrupt_processing
|
||||||
|
|
||||||
|
def throw_exception_if_processing_interrupted():
|
||||||
|
global interrupt_processing
|
||||||
|
global interrupt_processing_mutex
|
||||||
|
with interrupt_processing_mutex:
|
||||||
|
if interrupt_processing:
|
||||||
|
interrupt_processing = False
|
||||||
|
raise InterruptProcessingException()
|
||||||
|
|
|
@ -172,6 +172,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||||
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||||
del input_x
|
del input_x
|
||||||
|
|
||||||
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
for o in range(batch_chunks):
|
for o in range(batch_chunks):
|
||||||
if cond_or_uncond[o] == COND:
|
if cond_or_uncond[o] == COND:
|
||||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||||
|
|
|
@ -58,6 +58,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
|
||||||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
||||||
obj = class_def()
|
obj = class_def()
|
||||||
|
|
||||||
|
nodes.before_node_execution()
|
||||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
||||||
if "ui" in outputs[unique_id] and server.client_id is not None:
|
if "ui" in outputs[unique_id] and server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
||||||
|
|
7
nodes.py
7
nodes.py
|
@ -41,6 +41,13 @@ def recursive_search(directory):
|
||||||
def filter_files_extensions(files, extensions):
|
def filter_files_extensions(files, extensions):
|
||||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||||
|
|
||||||
|
|
||||||
|
def before_node_execution():
|
||||||
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
|
def interrupt_processing():
|
||||||
|
model_management.interrupt_current_processing()
|
||||||
|
|
||||||
class CLIPTextEncode:
|
class CLIPTextEncode:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
|
|
@ -141,6 +141,11 @@ class PromptServer():
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
@routes.post("/interrupt")
|
||||||
|
async def post_interrupt(request):
|
||||||
|
nodes.interrupt_processing()
|
||||||
|
return web.Response(status=200)
|
||||||
|
|
||||||
@routes.post("/history")
|
@routes.post("/history")
|
||||||
async def post_history(request):
|
async def post_history(request):
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
|
|
Loading…
Reference in New Issue