451 lines
20 KiB
Python
451 lines
20 KiB
Python
import nodes
|
|
from typing import Set, Tuple, Dict, List
|
|
|
|
from comfy_execution.graph_utils import is_link
|
|
|
|
class DependencyCycleError(Exception):
    """Signals that the pending execution graph contains a dependency cycle."""
|
|
|
|
class NodeInputError(Exception):
    """Raised when a node asks for an input that is absent or is a constant instead of a link."""
|
|
|
|
class NodeNotFoundError(Exception):
    """Raised when a node id is found in neither the original nor the ephemeral prompt."""
|
|
|
|
class DynamicPrompt:
    """View of a prompt graph: the user's original nodes plus nodes added during execution.

    Nodes created at run time live in a separate "ephemeral" overlay that is
    consulted before the original prompt. Each ephemeral node also records a
    parent id and a display id so it can be mapped back to a user-visible node.
    """

    def __init__(self, original_prompt):
        # Nodes exactly as submitted by the user.
        self.original_prompt = original_prompt
        self.node_definitions = DynamicNodeDefinitionCache(self)
        # Overlay of graph pieces created during execution, keyed by node id,
        # together with each one's parent id and display id.
        self.ephemeral_prompt = {}
        self.ephemeral_parents = {}
        self.ephemeral_display = {}

    def get_node(self, node_id):
        """Return the node dict for node_id, preferring the ephemeral overlay.

        Raises NodeNotFoundError when the id is unknown.
        """
        for source in (self.ephemeral_prompt, self.original_prompt):
            if node_id in source:
                return source[node_id]
        raise NodeNotFoundError(f"Node {node_id} not found")

    def get_node_definition(self, node_id):
        """Delegate to the per-prompt definition cache."""
        return self.node_definitions.get_node_definition(node_id)

    def has_node(self, node_id):
        """Return True if node_id exists in either the original prompt or the overlay."""
        return any(node_id in source for source in (self.original_prompt, self.ephemeral_prompt))

    def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
        """Register a run-time node together with its parent and display ids."""
        self.ephemeral_prompt[node_id] = node_info
        self.ephemeral_parents[node_id] = parent_id
        self.ephemeral_display[node_id] = display_id

    def get_real_node_id(self, node_id):
        """Follow parent links until reaching a node that has no recorded parent."""
        parents = self.ephemeral_parents
        current = node_id
        while current in parents:
            current = parents[current]
        return current

    def get_parent_node_id(self, node_id):
        """Return the recorded parent of node_id, or None if it has none."""
        return self.ephemeral_parents.get(node_id)

    def get_display_node_id(self, node_id):
        """Follow display links until reaching a node that has no recorded display id."""
        displays = self.ephemeral_display
        current = node_id
        while current in displays:
            current = displays[current]
        return current

    def all_node_ids(self):
        """Return the set of every node id in the original prompt and the overlay."""
        return set(self.original_prompt) | set(self.ephemeral_prompt)

    def get_original_prompt(self):
        """Return the prompt exactly as it was passed to the constructor."""
        return self.original_prompt
|
|
|
|
class DynamicNodeDefinitionCache:
    """Per-prompt cache of node definitions, including dynamically resolved ones.

    A definition starts out as the static class info from ``node_class_info``.
    For node classes that define ``resolve_dynamic_types`` it is refined by
    ``resolve_dynamic_definitions`` from the types flowing into and out of the node.
    """

    def __init__(self, dynprompt: "DynamicPrompt"):
        self.dynprompt = dynprompt
        # node_id -> definition dict (None when the class_type is not registered)
        self.definitions = {}
        # (from_node_id, from_socket) -> [(to_node_id, to_input_name), ...]
        self.inputs_from_output_slot = {}
        # from_node_id -> [(to_node_id, to_input_name), ...]
        self.inputs_from_output_node = {}

    def get_node_definition(self, node_id):
        """Return the cached definition for node_id, loading the static one on first use.

        Returns None if the stored node is None or its class_type is unknown.
        Whatever ``dynprompt.get_node`` raises for an unknown id is propagated
        (NodeNotFoundError for a DynamicPrompt).
        """
        if node_id not in self.definitions:
            node = self.dynprompt.get_node(node_id)
            if node is None:
                return None
            self.definitions[node_id] = node_class_info(node["class_type"])
        return self.definitions[node_id]

    def get_constant_type(self, value):
        """Return the type label used for a constant (non-link) input value, or None."""
        # bool is a subclass of int, so it must be tested first; otherwise True/False
        # would be labelled "INT,FLOAT" and the "BOOL" branch could never be reached.
        # NOTE(review): ComfyUI's boolean widget type is usually spelled "BOOLEAN";
        # confirm which label resolve_dynamic_types implementations expect.
        if isinstance(value, bool):
            return "BOOL"
        if isinstance(value, (int, float)):
            return "INT,FLOAT"
        if isinstance(value, str):
            return "STRING"
        return None

    def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
        """Describe the types currently connected to node_id.

        Returns (input_types, output_types):
        - input_types maps each linked input to the type of the upstream output
          ("*" when the socket index is out of range or the upstream definition is
          unknown), and each constant input to its constant type label; constants
          with no label are omitted.
        - output_types maps each output name that has consumers to the list of
          types those consumers declare for the connected input ("*" when the
          consumer declares it as neither required nor optional).
        """
        node = self.dynprompt.get_node(node_id)
        input_types: Dict[str, str] = {}
        for input_name, input_data in node["inputs"].items():
            if is_link(input_data):
                from_node_id, from_socket = input_data
                # The upstream node may lie outside the set being resolved, so go
                # through the cache instead of indexing self.definitions directly.
                from_def = self.get_node_definition(from_node_id)
                if from_def is not None and from_socket < len(from_def["output_name"]):
                    input_types[input_name] = from_def["output"][from_socket]
                else:
                    input_types[input_name] = "*"
            else:
                constant_type = self.get_constant_type(input_data)
                if constant_type is not None:
                    input_types[input_name] = constant_type

        output_types: Dict[str, List[str]] = {}
        node_def = self.get_node_definition(node_id)
        output_names = node_def["output_name"] if node_def is not None else ()
        for index, output_name in enumerate(output_names):
            consumers = self.inputs_from_output_slot.get((node_id, index))
            if not consumers:
                continue
            declared = output_types.setdefault(output_name, [])
            for to_node_id, to_input_name in consumers:
                to_def = self.get_node_definition(to_node_id)
                to_inputs = to_def["input"] if to_def is not None else {}
                # INPUT_TYPES() often has no "optional" (or even "required") section.
                required = to_inputs.get("required", {})
                optional = to_inputs.get("optional", {})
                if to_input_name in required:
                    declared.append(required[to_input_name][0])
                elif to_input_name in optional:
                    declared.append(optional[to_input_name][0])
                else:
                    declared.append("*")
        return input_types, output_types

    def resolve_dynamic_definitions(self, node_id_set: Set[str]):
        """Refine the definitions of node_id_set via each class's resolve_dynamic_types.

        Static definitions are (re)loaded for every node in node_id_set and their
        incoming links are indexed. Nodes whose class defines
        ``resolve_dynamic_types`` are then re-evaluated. Whenever a definition
        changes, the node itself, its entangled partners and its linked
        neighbours inside node_id_set are queued again.

        NOTE(review): termination relies on resolve_dynamic_types reaching a
        fixed point.

        Returns a dict mapping node_id -> definition for every node that was
        re-evaluated.
        """
        entangled = self._index_links(node_id_set)

        to_resolve = node_id_set.copy()
        updated = {}
        while len(to_resolve) > 0:
            node_id = to_resolve.pop()
            node = self.dynprompt.get_node(node_id)
            class_def = nodes.NODE_CLASS_MAPPINGS[node["class_type"]]
            if not hasattr(class_def, "resolve_dynamic_types"):
                continue

            entangled_types = self._collect_entangled_types(entangled.get(node_id, []))
            input_types, output_types = self.get_input_output_types(node_id)
            dynamic_info = class_def.resolve_dynamic_types(
                input_types=input_types,
                output_types=output_types,
                entangled_types=entangled_types
            )
            old_info = self.definitions[node_id].copy()
            self.definitions[node_id].update(dynamic_info)
            updated[node_id] = self.definitions[node_id]
            # We changed the info, so we potentially need to resolve adjacent and entangled nodes
            if old_info != self.definitions[node_id]:
                self._queue_neighbours(node_id, node, node_id_set, entangled, to_resolve)
        return updated

    def _index_links(self, node_id_set):
        """Load static definitions for node_id_set, index incoming links, and return entanglements.

        The returned dict maps node_id -> [(other_node_id, input_name), ...] for
        every link whose input is declared with ``entangleTypes``, in both directions.
        """
        entangled = {}
        for node_id in node_id_set:
            node = self.dynprompt.get_node(node_id)
            self.definitions[node_id] = node_class_info(node["class_type"])
            for input_name, input_data in node["inputs"].items():
                if not is_link(input_data):
                    continue
                input_tuple = tuple(input_data)
                consumer = (node_id, input_name)
                # These indexes persist across calls; skip entries that are already
                # recorded so re-resolving a node does not duplicate its consumers.
                slot_consumers = self.inputs_from_output_slot.setdefault(input_tuple, [])
                if consumer not in slot_consumers:
                    slot_consumers.append(consumer)
                node_consumers = self.inputs_from_output_node.setdefault(input_tuple[0], [])
                if consumer not in node_consumers:
                    node_consumers.append(consumer)

                _, _, extra_info = get_input_info(self.definitions[node_id], input_name)
                if isinstance(extra_info, dict) and extra_info.get("entangleTypes", False):
                    from_node_id = input_data[0]
                    entangled.setdefault(node_id, []).append((from_node_id, input_name))
                    entangled.setdefault(from_node_id, []).append((node_id, input_name))
        return entangled

    def _collect_entangled_types(self, entangled_pairs):
        """Build the ``entangled_types`` argument: input name -> list of partner type summaries."""
        entangled_types = {}
        for entangled_id, entangled_name in entangled_pairs:
            entangled_def = self.get_node_definition(entangled_id)
            if entangled_def is None:
                continue
            input_types = {}
            for input_category, input_list in entangled_def["input"].items():
                for input_name, input_info in input_list.items():
                    # Hidden entries are presumably not always tuples; only index those that are.
                    if isinstance(input_info, tuple) or input_category != "hidden":
                        input_types[input_name] = input_info[0]
            output_types = {}
            for i, output_type in enumerate(entangled_def["output"]):
                output_types[entangled_def["output_name"][i]] = output_type

            entangled_types.setdefault(entangled_name, []).append({
                "node_id": entangled_id,
                "input_types": input_types,
                "output_types": output_types
            })
        return entangled_types

    def _queue_neighbours(self, node_id, node, node_id_set, entangled, to_resolve):
        """Queue node_id and its entangled/linked neighbours within node_id_set for re-resolution."""
        for entangled_node_id, _ in entangled.get(node_id, []):
            if entangled_node_id in node_id_set:
                to_resolve.add(entangled_node_id)
        for i in range(len(self.definitions[node_id]["output"])):
            for output_node_id, _ in self.inputs_from_output_slot.get((node_id, i), []):
                if output_node_id in node_id_set:
                    to_resolve.add(output_node_id)
        for input_data in node["inputs"].values():
            if is_link(input_data) and input_data[0] in node_id_set:
                to_resolve.add(input_data[0])
        for to_node_id, _ in self.inputs_from_output_node.get(node_id, []):
            if to_node_id in node_id_set:
                to_resolve.add(to_node_id)
        # Because this run may have changed the number of outputs, we may need to run it again
        # in order to get those outputs passed as output_types.
        to_resolve.add(node_id)
|
|
|
|
def node_class_info(node_class):
    """Build the static definition dict for a registered node class name.

    Returns None when node_class is not a key of nodes.NODE_CLASS_MAPPINGS.
    Otherwise returns a dict with keys input, input_order, output,
    output_is_list, output_name, name, display_name, description,
    python_module, category and output_node, plus output_tooltips,
    deprecated and experimental when the class declares them.
    """
    if node_class not in nodes.NODE_CLASS_MAPPINGS:
        return None
    obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]

    # INPUT_TYPES() is user-defined and may be costly; call it once so that
    # 'input' and 'input_order' are derived from the very same dict.
    input_types = obj_class.INPUT_TYPES()

    info = {}
    info['input'] = input_types
    info['input_order'] = {key: list(value.keys()) for (key, value) in input_types.items()}
    info['output'] = obj_class.RETURN_TYPES
    info['output_is_list'] = getattr(obj_class, 'OUTPUT_IS_LIST', [False] * len(obj_class.RETURN_TYPES))
    info['output_name'] = getattr(obj_class, 'RETURN_NAMES', info['output'])
    info['name'] = node_class
    info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class)
    info['description'] = getattr(obj_class, 'DESCRIPTION', '')
    info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
    info['category'] = getattr(obj_class, 'CATEGORY', 'sd')
    # Equality (not truthiness) with True, as before; bool() keeps the value a plain bool.
    info['output_node'] = bool(getattr(obj_class, 'OUTPUT_NODE', False) == True)  # noqa: E712

    if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
        info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS

    if getattr(obj_class, "DEPRECATED", False):
        info['deprecated'] = True
    if getattr(obj_class, "EXPERIMENTAL", False):
        info['experimental'] = True

    return info
|
|
|
|
|
|
def get_input_info(node_info, input_name):
    """Look up how input_name is declared in a node definition.

    The "required", "optional" and "hidden" sections of node_info["input"]
    are searched in that order.

    Returns (input_type, input_category, extra_info). extra_info is the
    options entry that follows the type, or {} when there is none.
    Returns (None, None, None) when the input is not declared.
    """
    valid_inputs = node_info["input"]
    input_info = None
    input_category = None
    for category in ("required", "optional", "hidden"):
        if category in valid_inputs and input_name in valid_inputs[category]:
            input_category = category
            input_info = valid_inputs[category][input_name]
            break

    if input_info is None:
        return None, None, None

    if isinstance(input_info, str):
        # Some declarations (presumably hidden inputs such as "PROMPT") are a bare
        # type string rather than a (type, options) tuple. Indexing the string
        # would yield single characters for the type and extra info.
        return input_info, input_category, {}

    input_type = input_info[0]
    extra_info = input_info[1] if len(input_info) > 1 else {}
    return input_type, input_category, extra_info
|
|
|
|
class TopologicalSort:
    """Incremental dependency tracker used to order node execution.

    Nodes are added on demand, together with the uncached upstream nodes they
    need. A node becomes ready once every node blocking it has been popped.
    """

    def __init__(self, dynprompt):
        self.dynprompt = dynprompt
        self.pendingNodes = {}  # node_id -> True for every node added and not yet popped
        self.blockCount = {}  # Number of nodes this node is directly blocked by
        self.blocking = {}  # Which nodes are blocked by this node: from_id -> {to_id: {socket: True}}

    def get_input_info(self, unique_id, input_name):
        """Return (type, category, extra_info) for an input of the given node."""
        node_def = self.dynprompt.get_node_definition(unique_id)
        return get_input_info(node_def, input_name)

    def make_input_strong_link(self, to_node_id, to_input):
        """Make to_node_id wait on whatever node feeds its to_input.

        Raises NodeInputError when that input is absent or holds a constant.
        """
        node_inputs = self.dynprompt.get_node(to_node_id)["inputs"]
        if to_input not in node_inputs:
            raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
        link = node_inputs[to_input]
        if not is_link(link):
            raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
        source_id, source_socket = link
        self.add_strong_link(source_id, source_socket, to_node_id)

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
        """Record that to_node_id is blocked by from_node_id, unless from_node_id is cached."""
        if self.is_cached(from_node_id):
            return
        self.add_node(from_node_id)
        targets = self.blocking[from_node_id]
        if to_node_id not in targets:
            # First link between this pair: count it once as a blocker.
            targets[to_node_id] = {}
            self.blockCount[to_node_id] += 1
        targets[to_node_id][from_socket] = True

    def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
        """Add node_unique_id and, transitively, the uncached upstream nodes it depends on.

        Lazy inputs are followed only when include_lazy is True. When
        subgraph_nodes is given, links coming from nodes outside it are ignored.
        The collected links are registered only after the traversal finishes.
        """
        stack = [node_unique_id]
        deferred_links = []

        while stack:
            current = stack.pop()
            if current in self.pendingNodes:
                continue

            self.pendingNodes[current] = True
            self.blockCount[current] = 0
            self.blocking[current] = {}

            node_inputs = self.dynprompt.get_node(current)["inputs"]
            for input_name, value in node_inputs.items():
                if not is_link(value):
                    continue
                source_id, source_socket = value
                if subgraph_nodes is not None and source_id not in subgraph_nodes:
                    continue
                _, _, extra = self.get_input_info(current, input_name)
                lazy = extra is not None and "lazy" in extra and extra["lazy"]
                if (include_lazy or not lazy) and not self.is_cached(source_id):
                    stack.append(source_id)
                    deferred_links.append((source_id, source_socket, current))

        for link in deferred_links:
            self.add_strong_link(*link)

    def is_cached(self, node_id):
        """Base implementation: nothing is ever cached."""
        return False

    def get_ready_nodes(self):
        """Return pending nodes with no remaining blockers, in insertion order."""
        return [candidate for candidate in self.pendingNodes if self.blockCount[candidate] == 0]

    def pop_node(self, unique_id):
        """Remove a finished node and release every node it was blocking."""
        del self.pendingNodes[unique_id]
        for dependent in self.blocking[unique_id]:
            self.blockCount[dependent] -= 1
        del self.blocking[unique_id]

    def is_empty(self):
        """Return True once no nodes remain pending."""
        return not self.pendingNodes
|
|
|
|
class ExecutionList(TopologicalSort):
    """
    ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
    it can still be returned to the graph after having further dependencies added.
    """
    def __init__(self, dynprompt, output_cache):
        super().__init__(dynprompt)
        # Consulted by is_cached(): nodes with a non-None cached entry are not scheduled.
        self.output_cache = output_cache
        # Node most recently handed out by stage_node_execution(), or None.
        self.staged_node_id = None

    def is_cached(self, node_id):
        """Return True when output_cache holds a non-None entry for node_id."""
        return self.output_cache.get(node_id) is not None

    def stage_node_execution(self):
        """Pick the next node to execute and mark it as staged.

        Returns (node_id, error_details, exception):
        - (None, None, None) when no nodes remain;
        - (node_id, None, None) for the newly staged node;
        - (None, error_details, DependencyCycleError) when nodes remain but none
          is ready, i.e. the remaining graph contains a dependency cycle.
        """
        assert self.staged_node_id is None
        if self.is_empty():
            return None, None, None
        available = self.get_ready_nodes()
        if len(available) == 0:
            cycled_nodes = self.get_nodes_in_cycle()
            if not cycled_nodes:
                # Defensive: should not happen, but never fail with IndexError here.
                cycled_nodes = list(self.pendingNodes)
            # Because cycles composed entirely of static nodes are caught during initial validation,
            # we will 'blame' the first node in the cycle that is not a static node.
            blamed_node = cycled_nodes[0]
            for node_id in cycled_nodes:
                display_node_id = self.dynprompt.get_display_node_id(node_id)
                if display_node_id != node_id:
                    blamed_node = display_node_id
                    break
            ex = DependencyCycleError("Dependency cycle detected")
            error_details = {
                "node_id": blamed_node,
                "exception_message": str(ex),
                "exception_type": "graph.DependencyCycleError",
                "traceback": [],
                "current_inputs": []
            }
            return None, error_details, ex

        self.staged_node_id = self.ux_friendly_pick_node(available)
        return self.staged_node_id, None, None

    def ux_friendly_pick_node(self, node_list):
        """Choose which ready node to run next, favouring nodes close to an output node."""
        # If an output node is available, do that first.
        # Technically this has no effect on the overall length of execution, but it feels better as a user
        # for a PreviewImage to display a result as soon as it can
        # Some other heuristics could probably be used here to improve the UX further.
        def is_output(node_id):
            node_def = self.dynprompt.get_node_definition(node_id)
            return node_def['output_node']

        for node_id in node_list:
            if is_output(node_id):
                return node_id

        #This should handle the VAEDecode -> preview case
        for node_id in node_list:
            for blocked_node_id in self.blocking[node_id]:
                if is_output(blocked_node_id):
                    return node_id

        #This should handle the VAELoader -> VAEDecode -> preview case
        for node_id in node_list:
            for blocked_node_id in self.blocking[node_id]:
                for blocked_node_id1 in self.blocking[blocked_node_id]:
                    if is_output(blocked_node_id1):
                        return node_id

        #TODO: this function should be improved
        return node_list[0]

    def unstage_node_execution(self):
        """Return the staged node to the pending pool without completing it."""
        assert self.staged_node_id is not None
        self.staged_node_id = None

    def complete_node_execution(self):
        """Mark the staged node as finished, releasing the nodes it blocks."""
        assert self.staged_node_id is not None
        node_id = self.staged_node_id
        self.pop_node(node_id)
        self.staged_node_id = None

    def get_nodes_in_cycle(self):
        """Return the pending nodes that lie on a dependency cycle, in pending order.

        The blocking graph is dissolved from both ends. First, nodes whose
        remaining blockers have all been removed are peeled off (forward
        topological order). Then nodes that no longer block any remaining node
        are peeled off (reverse order). A forward pass alone would also keep
        nodes that merely sit downstream of a cycle, and one of those could be
        blamed for it. What is left are the nodes on a cycle, or on a path
        between two cycles. Simplicity wins over speed here, because having a
        cycle in the first place is a catastrophic error.
        """
        remaining = set(self.pendingNodes)
        blocked_by = {node_id: set() for node_id in remaining}
        blocks = {node_id: set() for node_id in remaining}
        for from_node_id, targets in self.blocking.items():
            for to_node_id, sockets in targets.items():
                if True in sockets.values() and from_node_id in remaining and to_node_id in remaining:
                    blocked_by[to_node_id].add(from_node_id)
                    blocks[from_node_id].add(to_node_id)

        def peel(edges):
            # Repeatedly drop nodes that have no edge to any other remaining node.
            changed = True
            while changed:
                changed = False
                for node_id in list(remaining):
                    if remaining.isdisjoint(edges[node_id]):
                        remaining.discard(node_id)
                        changed = True

        peel(blocked_by)
        peel(blocks)
        return [node_id for node_id in self.pendingNodes if node_id in remaining]
|
|
|
|
class ExecutionBlocker:
    """
    Sentinel a node can return to block every consumer of that output.

    Consumers are blocked with the given error message. If the message is None,
    they are blocked silently instead.

    Prefer lazy inputs wherever they are possible: they are more efficient and
    give a better user experience. Reserve this for the two cases where it is
    genuinely needed:
    1. Conditionally preventing an output node from executing, particularly a
       built-in one such as SaveImage. For your own output nodes, a BOOL input
       combined with lazy evaluation lets the node disable itself.
    2. A node with several outputs, some of which may be invalid and must not be
       used. New nodes should avoid this pattern and be split into separate nodes
       with different outputs instead, but several popular existing nodes rely on it.
    """

    def __init__(self, message):
        # Error text reported for blocked consumers; None blocks them silently.
        self.message = message
|
|
|