Persist node instances between executions instead of deleting them.

If the same node id with the same class exists between two executions the
same instance will be used.

This means you can now cache things in nodes for more efficiency.
This commit is contained in:
comfyanonymous 2023-06-29 23:38:56 -04:00
parent 9920367d3c
commit 6e9f28401f
1 changed files with 20 additions and 4 deletions

View File

@ -110,7 +110,7 @@ def format_value(x):
else: else:
return str(x) return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
@ -125,7 +125,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True: if result[0] is not True:
# Another node failed further upstream # Another node failed further upstream
return result return result
@ -136,7 +136,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = unique_id server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = object_storage.get((unique_id, class_type), None)
if obj is None:
obj = class_def() obj = class_def()
object_storage[(unique_id, class_type)] = obj
output_data, output_ui = get_output_data(obj, input_data_all) output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data outputs[unique_id] = output_data
@ -256,6 +260,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor: class PromptExecutor:
def __init__(self, server): def __init__(self, server):
self.outputs = {} self.outputs = {}
self.object_storage = {}
self.outputs_ui = {} self.outputs_ui = {}
self.old_prompt = {} self.old_prompt = {}
self.server = server self.server = server
@ -322,6 +327,17 @@ class PromptExecutor:
for o in to_delete: for o in to_delete:
d = self.outputs.pop(o) d = self.outputs.pop(o)
del d del d
to_delete = []
for o in self.object_storage:
if o[0] not in prompt:
to_delete += [o]
else:
p = prompt[o[0]]
if o[1] != p['class_type']:
to_delete += [o]
for o in to_delete:
d = self.object_storage.pop(o)
del d
for x in prompt: for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
@ -349,7 +365,7 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in # This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the # the actual SD code, instead it will report the node where the
# error was raised # error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if success is not True: if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break break