Implement Dynamic Typing

This is a proof of concept to get feedback. Note that it requires the
frontend branch of the same name.
This commit is contained in:
Jacob Segal 2024-08-21 18:46:11 -07:00
parent 0dbba9f751
commit a063468444
4 changed files with 423 additions and 77 deletions

View File

@ -1,4 +1,5 @@
import nodes
from typing import Set, Tuple, Dict, List
from comfy_execution.graph_utils import is_link
@ -15,6 +16,7 @@ class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
self.original_prompt = original_prompt
self.node_definitions = DynamicNodeDefinitionCache(self)
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
@ -27,6 +29,9 @@ class DynamicPrompt:
return self.original_prompt[node_id]
raise NodeNotFoundError(f"Node {node_id} not found")
def get_node_definition(self, node_id):
return self.node_definitions.get_node_definition(node_id)
def has_node(self, node_id):
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
@ -54,8 +59,188 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
class DynamicNodeDefinitionCache:
    """Caches the per-node definition dicts for a DynamicPrompt.

    Definitions start as the static class info (see node_class_info) and are
    re-resolved for nodes that expose a resolve_dynamic_types classmethod, so
    dynamically typed inputs/outputs reflect the current graph connections.
    """
    def __init__(self, dynprompt: DynamicPrompt):
        self.dynprompt = dynprompt
        # node_id -> definition dict (same shape as node_class_info's result)
        self.definitions = {}
        # (from_node_id, output_index) -> [(to_node_id, to_input_name), ...]
        self.inputs_from_output_slot = {}
        # from_node_id -> [(to_node_id, to_input_name), ...]
        self.inputs_from_output_node = {}

    def get_node_definition(self, node_id):
        """Return (and lazily cache) the definition dict for node_id.

        Returns None when the class type is unknown (node_class_info returns
        None). NOTE(review): DynamicPrompt.get_node appears to raise
        NodeNotFoundError rather than return None, so the None check below may
        be defensive only -- confirm against DynamicPrompt.
        """
        if node_id not in self.definitions:
            node = self.dynprompt.get_node(node_id)
            if node is None:
                return None
            class_type = node["class_type"]
            definition = node_class_info(class_type)
            self.definitions[node_id] = definition
        return self.definitions[node_id]

    def get_constant_type(self, value):
        """Map a constant (non-link) input value to a type string, or None."""
        # bool must be tested before (int, float): bool is a subclass of int,
        # so the original ordering made the "BOOL" branch unreachable and
        # classified True/False as "INT,FLOAT".
        if isinstance(value, bool):
            return "BOOL"
        elif isinstance(value, (int, float)):
            return "INT,FLOAT"
        elif isinstance(value, str):
            return "STRING"
        else:
            return None

    def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
        """Return (input_types, output_types) for node_id.

        input_types maps each connected/constant input name to a type string;
        output_types maps each output name to the list of types expected by
        the inputs attached to that output slot.
        """
        node = self.dynprompt.get_node(node_id)
        input_types: Dict[str, str] = {}
        for input_name, input_data in node["inputs"].items():
            if is_link(input_data):
                from_node_id, from_socket = input_data
                if from_socket < len(self.definitions[from_node_id]["output_name"]):
                    input_types[input_name] = self.definitions[from_node_id]["output"][from_socket]
                else:
                    # Link points past the upstream node's known outputs.
                    input_types[input_name] = "*"
            else:
                constant_type = self.get_constant_type(input_data)
                if constant_type is not None:
                    input_types[input_name] = constant_type
        output_types: Dict[str, List[str]] = {}
        for index in range(len(self.definitions[node_id]["output_name"])):
            output_name = self.definitions[node_id]["output_name"][index]
            if (node_id, index) not in self.inputs_from_output_slot:
                continue
            for (to_node_id, to_input_name) in self.inputs_from_output_slot[(node_id, index)]:
                if output_name not in output_types:
                    output_types[output_name] = []
                # Use .get with a default: INPUT_TYPES dicts commonly omit the
                # "optional" (and sometimes "required") key, and plain
                # indexing raised KeyError here.
                to_def = self.definitions[to_node_id]
                if to_input_name in to_def["input"].get("required", {}):
                    output_types[output_name].append(to_def["input"]["required"][to_input_name][0])
                elif to_input_name in to_def["input"].get("optional", {}):
                    output_types[output_name].append(to_def["input"]["optional"][to_input_name][0])
                else:
                    output_types[output_name].append("*")
        return input_types, output_types

    def resolve_dynamic_definitions(self, node_id_set: Set[str]):
        """Resolve dynamic definitions for every node in node_id_set.

        Iterates to a fixed point: whenever a node's definition changes, its
        neighbors (linked and entangled) are queued for re-resolution.
        Returns the dict of node_id -> updated definition for nodes whose
        class provided resolve_dynamic_types.
        """
        entangled = {}
        # Pre-fill with class info. Also, build a lookup table for output nodes
        for node_id in node_id_set:
            node = self.dynprompt.get_node(node_id)
            class_type = node["class_type"]
            self.definitions[node_id] = node_class_info(class_type)
            for input_name, input_data in node["inputs"].items():
                if is_link(input_data):
                    input_tuple = tuple(input_data)
                    if input_tuple not in self.inputs_from_output_slot:
                        self.inputs_from_output_slot[input_tuple] = []
                    self.inputs_from_output_slot[input_tuple].append((node_id, input_name))
                    if input_tuple[0] not in self.inputs_from_output_node:
                        self.inputs_from_output_node[input_tuple[0]] = []
                    self.inputs_from_output_node[input_tuple[0]].append((node_id, input_name))
                    # "entangleTypes" marks inputs whose type resolution must be
                    # shared with the node on the other end of the link.
                    _, _, extra_info = get_input_info(self.definitions[node_id], input_name)
                    if extra_info is not None and extra_info.get("entangleTypes", False):
                        from_node_id = input_data[0]
                        if node_id not in entangled:
                            entangled[node_id] = []
                        if from_node_id not in entangled:
                            entangled[from_node_id] = []
                        entangled[node_id].append((from_node_id, input_name))
                        entangled[from_node_id].append((node_id, input_name))
        # Evaluate node info
        to_resolve = node_id_set.copy()
        updated = {}
        while len(to_resolve) > 0:
            node_id = to_resolve.pop()
            node = self.dynprompt.get_node(node_id)
            class_type = node["class_type"]
            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
            if hasattr(class_def, "resolve_dynamic_types"):
                # Collect the input/output types of every entangled partner so
                # the node can mirror them.
                entangled_types = {}
                for (entangled_id, entangled_name) in entangled.get(node_id, []):
                    entangled_def = self.get_node_definition(entangled_id)
                    if entangled_def is None:
                        continue
                    input_types = {}
                    output_types = {}
                    for input_category, input_list in entangled_def["input"].items():
                        for input_name, input_info in input_list.items():
                            # NOTE(review): "or" (not "and") is in the original;
                            # tuples are included from any category, non-tuples
                            # only when not hidden. input_info[0] on a non-tuple
                            # (e.g. a hidden string) would index its first
                            # character -- confirm intent.
                            if isinstance(input_info, tuple) or input_category != "hidden":
                                input_types[input_name] = input_info[0]
                    for i in range(len(entangled_def["output"])):
                        output_name = entangled_def["output_name"][i]
                        output_types[output_name] = entangled_def["output"][i]
                    if entangled_name not in entangled_types:
                        entangled_types[entangled_name] = []
                    entangled_types[entangled_name].append({
                        "node_id": entangled_id,
                        "input_types": input_types,
                        "output_types": output_types
                    })
                input_types, output_types = self.get_input_output_types(node_id)
                dynamic_info = class_def.resolve_dynamic_types(
                    input_types=input_types,
                    output_types=output_types,
                    entangled_types=entangled_types
                )
                old_info = self.definitions[node_id].copy()
                self.definitions[node_id].update(dynamic_info)
                updated[node_id] = self.definitions[node_id]
                # We changed the info, so we potentially need to resolve adjacent and entangled nodes
                if old_info != self.definitions[node_id]:
                    for (entangled_node_id, _) in entangled.get(node_id, []):
                        if entangled_node_id in node_id_set:
                            to_resolve.add(entangled_node_id)
                    for i in range(len(self.definitions[node_id]["output"])):
                        for (output_node_id, _) in self.inputs_from_output_slot.get((node_id, i), []):
                            if output_node_id in node_id_set:
                                to_resolve.add(output_node_id)
                    for _, input_data in node["inputs"].items():
                        if is_link(input_data):
                            if input_data[0] in node_id_set:
                                to_resolve.add(input_data[0])
                    for (to_node_id, _) in self.inputs_from_output_node.get(node_id, []):
                        if to_node_id in node_id_set:
                            to_resolve.add(to_node_id)
                    # Because this run may have changed the number of outputs, we may need to run it again
                    # in order to get those outputs passed as output_types.
                    to_resolve.add(node_id)
        return updated
def node_class_info(node_class):
    """Build the definition dict for a registered node class.

    The dict carries the node's input schema, output types/names, and display
    metadata; it is the same shape served by /object_info and consumed by
    DynamicNodeDefinitionCache. Returns None when node_class is not in
    nodes.NODE_CLASS_MAPPINGS.
    """
    if node_class not in nodes.NODE_CLASS_MAPPINGS:
        return None
    obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
    # Call INPUT_TYPES() exactly once: it may be dynamically generated, so two
    # calls could be expensive or even disagree between 'input' and 'input_order'.
    input_types = obj_class.INPUT_TYPES()
    info = {}
    info['input'] = input_types
    info['input_order'] = {key: list(value.keys()) for (key, value) in input_types.items()}
    info['output'] = obj_class.RETURN_TYPES
    info['output_is_list'] = getattr(obj_class, 'OUTPUT_IS_LIST', [False] * len(obj_class.RETURN_TYPES))
    info['output_name'] = getattr(obj_class, 'RETURN_NAMES', info['output'])
    info['name'] = node_class
    info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class)
    info['description'] = getattr(obj_class, 'DESCRIPTION', '')
    info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
    info['category'] = getattr(obj_class, 'CATEGORY', 'sd')
    # Preserve the original "== True" semantics: only a literal True/1 counts.
    info['output_node'] = getattr(obj_class, 'OUTPUT_NODE', False) == True
    if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
        info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
    if getattr(obj_class, "DEPRECATED", False):
        info['deprecated'] = True
    if getattr(obj_class, "EXPERIMENTAL", False):
        info['experimental'] = True
    return info
def get_input_info(node_info, input_name):
valid_inputs = node_info["input"]
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
@ -84,9 +269,7 @@ class TopologicalSort:
self.blocking = {} # Which nodes are blocked by this node
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
return get_input_info(self.dynprompt.get_node_definition(unique_id), input_name)
def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
@ -197,11 +380,8 @@ class ExecutionList(TopologicalSort):
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
return False
node_def = self.dynprompt.get_node_definition(node_id)
return node_def['output_node']
for node_id in node_list:
if is_output(node_id):

View File

@ -0,0 +1,173 @@
import re
from typing import Optional, Tuple
# This decorator can be used to enable a "template" syntax for types in a node.
#
# Dynamic Types
# When specifying a type for an input or output, you can wrap an arbitrary string in angle brackets to indicate that it is dynamic. For example, the type "<FOO>" will be the equivalent of "*" (with the commonly used hacks) with the caveat that all inputs/outputs with the same template name ("FOO" in this case) must have the same type. Use multiple different template names if you want to allow types to differ. Note that this only applies within a single instance of a node -- different nodes can have different type resolutions
#
# Wrapping Types
# Rather than using JUST a template type, you can also use a template type with a wrapping type. For example, if you have a node that takes two inputs with the types "<FOO>" and "Accumulation<FOO>" respectively, any output can be connected to the "<FOO>" input. Once that input has a value (let's say an IMAGE), the other input will resolve as well (to Accumulation<IMAGE> in this example).
#
# Variadic Inputs
# Sometimes, you want a node to take a dynamic number of inputs. To do this, create an input value that has a name followed by a number sign and a string (e.g. "input#COUNT"). This will cause additional inputs to be added and removed as the user attaches to those sockets. The string after the '#' can be used to ensure that you have the same number of sockets for two different inputs. For example, having inputs named "image#FOO" and "mask#BAR" will allow the number of images and the number of masks to dynamically increase independently. Having inputs named "image#FOO" and "mask#FOO" will ensure that there are the same number of images as masks.
#
# Variadic Input - Same Type
# If you want to have a variadic input with a dynamic type, you can combine the two. For example, if you have an input named "input#COUNT" with the type "<FOO>", you can attach multiple inputs to that socket. Once you attach a value to one of the inputs, all of the other inputs will resolve to the same type. This is useful for nodes that take a dynamic number of inputs of the same type.
#
# Variadic Input - Different Types
# If you want each socket of a variadic input to resolve to its own independent type, use a template name that itself carries the group suffix. For example, if you have an input named "input#COUNT" with the type "<FOO#COUNT>", each socket for the input can have a different type. (Internally, this is equivalent to making the type <FOO1> where 1 is the index of this input.)
#
# Restrictions
# - All dynamic inputs must have `"forceInput": True` due to frontend reasons that will hopefully be resolved before merging.
def TemplateTypeSupport():
    """Class decorator adding template ("<FOO>") and variadic ("name#GROUP")
    type support to a node class.

    At decoration time, INPUT_TYPES/RETURN_TYPES are rewritten so template
    types are erased to "*" and each variadic group starts with one socket
    (index 1). A resolve_dynamic_types classmethod is attached, which the
    execution engine calls to re-resolve concrete types from the currently
    connected graph.
    """
    def decorator(cls):
        old_input_types = getattr(cls, "INPUT_TYPES")

        def new_input_types(cls):
            # Static view of the inputs: templates erased, variadic groups
            # collapsed to their first (index 1) socket.
            old_types = old_input_types()
            new_types = {
                "required": {},
                "optional": {},
                "hidden": old_types.get("hidden", {}),
            }
            for category in ["required", "optional"]:
                if category not in old_types:
                    continue
                for key, value in old_types[category].items():
                    new_types[category][replace_variadic_suffix(key, 1)] = (template_to_type(value[0]),) + value[1:]
            return new_types
        setattr(cls, "INPUT_TYPES", classmethod(new_input_types))
        old_outputs = getattr(cls, "RETURN_TYPES")
        setattr(cls, "RETURN_TYPES", tuple(template_to_type(x) for x in old_outputs))

        def resolve_dynamic_types(cls, input_types, output_types, entangled_types):
            """Recompute concrete input/output types from the connected graph.

            input_types maps input name -> connected type string; output_types
            maps output name -> list of types expected by attached inputs.
            Returns the definition fragments to merge into the node definition
            (input/output/output_name/dynamic_counts).
            """
            original_inputs = old_input_types()
            # Step 1 - Find all variadic groups and determine their maximum used index
            variadic_group_map = {}
            max_group_index = {}
            for category in ["required", "optional"]:
                for key, value in original_inputs.get(category, {}).items():
                    root, group = determine_variadic_group(key)
                    if root is not None and group is not None:
                        variadic_group_map[root] = group
            for type_map in [input_types, output_types]:
                for key in type_map.keys():
                    root, index = determine_variadic_suffix(key)
                    if root is not None and index is not None:
                        if root in variadic_group_map:
                            group = variadic_group_map[root]
                            max_group_index[group] = max(max_group_index.get(group, 0), index)
            # Step 2 - Create variadic arguments
            variadic_inputs = {
                "required": {},
                "optional": {},
            }
            for category in ["required", "optional"]:
                for key, value in original_inputs.get(category, {}).items():
                    root, group = determine_variadic_group(key)
                    if root is None or group is None:
                        # Copy it over as-is
                        variadic_inputs[category][key] = value
                    else:
                        # One socket per used index plus one trailing empty socket.
                        for i in range(1, max_group_index.get(group, 0) + 2):
                            # Also replace any variadic suffixes in the type (for use with templates)
                            input_type = value[0]
                            if isinstance(input_type, str):
                                input_type = replace_variadic_suffix(input_type, i)
                            # Fix: keep ALL trailing tuple elements. The original
                            # "(input_type, value[1])" raised IndexError for
                            # 1-tuples like ("<FOO>",) and dropped extras; the
                            # sibling paths here use "+ value[1:]".
                            variadic_inputs[category][replace_variadic_suffix(key, i)] = (input_type,) + value[1:]
            # Step 3 - Resolve template arguments
            resolved = {}
            for category in ["required", "optional"]:
                for key, value in variadic_inputs[category].items():
                    if key in input_types:
                        tkey, tvalue = determine_template_value(value[0], input_types[key])
                        if tkey is not None and tvalue is not None:
                            resolved[tkey] = type_intersection(resolved.get(tkey, "*"), tvalue)
            # NOTE(review): assumes the class declares RETURN_NAMES; classes
            # relying on the RETURN_TYPES fallback would raise here -- confirm.
            for i in range(len(old_outputs)):
                output_name = cls.RETURN_NAMES[i]
                if output_name in output_types:
                    for output_type in output_types[output_name]:
                        tkey, tvalue = determine_template_value(old_outputs[i], output_type)
                        if tkey is not None and tvalue is not None:
                            resolved[tkey] = type_intersection(resolved.get(tkey, "*"), tvalue)
            # Step 4 - Replace templates with resolved types
            final_inputs = {
                "required": {},
                "optional": {},
                "hidden": original_inputs.get("hidden", {}),
            }
            for category in ["required", "optional"]:
                for key, value in variadic_inputs[category].items():
                    final_inputs[category][key] = (template_to_type(value[0], resolved),) + value[1:]
            outputs = (template_to_type(x, resolved) for x in old_outputs)
            return {
                "input": final_inputs,
                "output": tuple(outputs),
                "output_name": cls.RETURN_NAMES,
                "dynamic_counts": max_group_index,
            }
        setattr(cls, "resolve_dynamic_types", classmethod(resolve_dynamic_types))
        return cls
    return decorator
def type_intersection(a: str, b: str) -> str:
    """Intersect two comma-separated type strings.

    "*" is a wildcard that matches anything. When the two type sets share no
    members, "*" is returned as the fallback.
    """
    wildcard = "*"
    if a == wildcard:
        return b
    if b == wildcard:
        return a
    if a == b:
        return a
    common = set(a.split(",")) & set(b.split(","))
    return ",".join(common) if common else wildcard
# Pattern helpers for template/variadic type strings.
# "<FOO>"            -> naked template (whole string is the template key)
naked_template_regex = re.compile(r"^<(.+)>$")
# "Wrapper<FOO>"     -> qualified template (wrapper type around a key)
qualified_template_regex = re.compile(r"^(.+)<(.+)>$")
# "name#GROUP"       -> variadic declaration (root name + group id)
variadic_template_regex = re.compile(r"([^<]+)#([^>]+)")
# "name12"           -> instantiated variadic socket (root name + index).
# The root must be non-greedy and the digits anchored at the end: the previous
# greedy, unanchored r"([^<]+)(\d+)" split "input12" into ("input1", "2"),
# mis-parsing any index >= 10.
variadic_suffix_regex = re.compile(r"([^<]+?)(\d+)$")
empty_lookup = {}
def template_to_type(template, key_lookup=empty_lookup):
    """Convert a template type string to a concrete type via key_lookup.

    "<KEY>" becomes the looked-up type ("*" when unresolved); a wrapped form
    such as "Wrapper<KEY>" keeps its wrapper and substitutes the inner key.
    Any other string is returned unchanged.
    """
    naked = naked_template_regex.match(template)
    if naked is not None:
        return key_lookup.get(naked.group(1), "*")
    wrapped = qualified_template_regex.match(template)
    if wrapped is None:
        return template
    inner = key_lookup.get(wrapped.group(2), "*")
    return qualified_template_regex.sub(r"\1<%s>" % inner, template)
def determine_template_value(template: str, actual_type: str) -> Tuple[Optional[str], Optional[str]]:
    """Return the (template key, concrete value) implied by matching
    actual_type against template, or (None, None) when they do not match.
    """
    naked = naked_template_regex.match(template)
    if naked is not None:
        return naked.group(1), actual_type
    templ = qualified_template_regex.match(template)
    actual = qualified_template_regex.match(actual_type)
    if templ is None or actual is None or templ.group(1) != actual.group(1):
        return None, None
    return templ.group(2), actual.group(2)
def determine_variadic_group(template: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a variadic declaration "root#GROUP" into (root, group);
    returns (None, None) when the name is not variadic.
    """
    found = variadic_template_regex.match(template)
    if found is None:
        return None, None
    return found.group(1), found.group(2)
def replace_variadic_suffix(template: str, index: int) -> str:
    """Instantiate a variadic name: "root#GROUP" becomes "root<index>".
    Strings without a variadic marker pass through unchanged.
    """
    suffix = str(index)
    return variadic_template_regex.sub(lambda found: found.group(1) + suffix, template)
def determine_variadic_suffix(template: str) -> Tuple[Optional[str], Optional[int]]:
    """Split an instantiated variadic name "rootN" into (root, N as int);
    returns (None, None) when there is no trailing index.
    """
    found = variadic_suffix_regex.match(template)
    if found is None:
        return None, None
    return found.group(1), int(found.group(2))

View File

@ -7,13 +7,13 @@ import time
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
from typing import List, Literal, NamedTuple, Optional, Dict, Tuple
import torch
import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker, node_class_info
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args
@ -37,8 +37,8 @@ class IsChangedCache:
return self.is_changed[node_id]
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[node["class_type"]]
node_def = self.dynprompt.get_node_definition(node_id)
if not hasattr(class_def, "IS_CHANGED"):
self.is_changed[node_id] = False
return self.is_changed[node_id]
@ -48,7 +48,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _ = get_input_data(node["inputs"], node_def, node_id, None)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
@ -87,13 +87,13 @@ class CacheSet:
}
return result
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
def get_input_data(inputs, node_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = node_def['input']
input_data_all = {}
missing_keys = {}
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x)
input_type, input_category, input_info = get_input_info(node_def, x)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
@ -126,6 +126,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "NODE_DEFINITION":
input_data_all[x] = [node_def]
return input_data_all, missing_keys
map_node_over_list = None #Don't hook this please
@ -169,12 +171,12 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
process_inputs(input_dict, i)
return results
def merge_result_data(results, obj):
def merge_result_data(results, node_def):
# check which outputs need concatenating
output = []
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
output_is_list = node_def['output_is_list']
if len(output_is_list) < len(results[0]):
output_is_list = output_is_list + [False] * (len(results[0]) - len(output_is_list))
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
@ -190,13 +192,14 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
def get_output_data(obj, node_def, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False
num_outputs = len(node_def['output'])
for i in range(len(return_values)):
r = return_values[i]
if isinstance(r, dict):
@ -208,24 +211,24 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
new_graph = r['expand']
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
result = tuple([result] * num_outputs)
subgraph_results.append((new_graph, result))
elif 'result' in r:
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
result = tuple([result] * num_outputs)
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
r = tuple([r] * num_outputs)
results.append(r)
subgraph_results.append((None, r))
if has_subgraph:
output = subgraph_results
elif len(results) > 0:
output = merge_result_data(results, obj)
output = merge_result_data(results, node_def)
else:
output = []
ui = dict()
@ -249,6 +252,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
node_def = dynprompt.get_node_definition(unique_id)
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
@ -275,11 +279,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
else:
resolved_output.append(r)
resolved_outputs.append(tuple(resolved_output))
output_data = merge_result_data(resolved_outputs, class_def)
output_data = merge_result_data(resolved_outputs, node_def)
output_ui = []
has_subgraph = False
else:
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
input_data_all, missing_keys = get_input_data(inputs, node_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -320,7 +324,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return block
def pre_execute_cb(call_index):
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
output_data, output_ui, has_subgraph = get_output_data(obj, node_def, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
"meta": {
@ -351,10 +355,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
# Figure out if the newly created node is an output node
class_type = node_info["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
dynprompt.node_definitions.resolve_dynamic_definitions(set(new_graph.keys()))
# Figure out if the newly created node is an output node
for node_id, node_info in new_graph.items():
node_def = dynprompt.get_node_definition(node_id)
if node_def['output_node']:
new_output_ids.append(node_id)
for i in range(len(node_outputs)):
if is_link(node_outputs[i]):
@ -470,6 +475,7 @@ class PromptExecutor:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
dynamic_prompt.node_definitions.resolve_dynamic_definitions(set(dynamic_prompt.all_node_ids()))
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
@ -528,7 +534,7 @@ class PromptExecutor:
def validate_inputs(prompt, item, validated):
def validate_inputs(dynprompt, prompt, item, validated):
unique_id = item
if unique_id in validated:
return validated[unique_id]
@ -536,8 +542,9 @@ def validate_inputs(prompt, item, validated):
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
node_def = dynprompt.get_node_definition(unique_id)
class_inputs = obj_class.INPUT_TYPES()
class_inputs = node_def['input']
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
errors = []
@ -552,7 +559,7 @@ def validate_inputs(prompt, item, validated):
received_types = {}
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
type_input, input_category, extra_info = get_input_info(node_def, x)
assert extra_info is not None
if x not in inputs:
if input_category == "required":
@ -585,8 +592,9 @@ def validate_inputs(prompt, item, validated):
continue
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
o_node_def = dynprompt.get_node_definition(o_id)
r = o_node_def['output']
assert r is not None
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
@ -605,7 +613,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
r = validate_inputs(dynprompt, prompt, o_id, validated)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
@ -713,7 +721,7 @@ def validate_inputs(prompt, item, validated):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
input_data_all, _ = get_input_data(inputs, node_def, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:
@ -756,6 +764,8 @@ def full_type_name(klass):
return module + '.' + klass.__qualname__
def validate_prompt(prompt):
dynprompt = DynamicPrompt(prompt)
dynprompt.node_definitions.resolve_dynamic_definitions(set(dynprompt.all_node_ids()))
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
@ -768,8 +778,8 @@ def validate_prompt(prompt):
return (False, error, [], [])
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
if class_ is None:
node_def = dynprompt.get_node_definition(x)
if node_def is None:
error = {
"type": "invalid_prompt",
"message": f"Cannot execute because node {class_type} does not exist.",
@ -778,7 +788,7 @@ def validate_prompt(prompt):
}
return (False, error, [], [])
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
if node_def['output_node']:
outputs.add(x)
if len(outputs) == 0:
@ -798,7 +808,7 @@ def validate_prompt(prompt):
valid = False
reasons = []
try:
m = validate_inputs(prompt, o, validated)
m = validate_inputs(dynprompt, prompt, o, validated)
valid = m[0]
reasons = m[1]
except Exception as ex:

View File

@ -32,6 +32,7 @@ from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
from comfy_execution.graph import DynamicPrompt, DynamicNodeDefinitionCache, node_class_info
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@ -525,43 +526,13 @@ class PromptServer():
async def get_prompt(request):
return web.json_response(self.get_queue_info())
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
if getattr(obj_class, "DEPRECATED", False):
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True
return info
@routes.get("/object_info")
async def get_object_info(request):
with folder_paths.cache_helper:
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
out[x] = node_class_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
@ -572,7 +543,7 @@ class PromptServer():
node_class = request.match_info.get("node_class", None)
out = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class)
out[node_class] = node_class_info(node_class)
return web.json_response(out)
@routes.get("/history")
@ -595,6 +566,18 @@ class PromptServer():
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@routes.post("/resolve_dynamic_types")
async def resolve_dynamic_types(request):
json_data = await request.json()
if 'prompt' not in json_data:
return web.json_response({"error": "no prompt"}, status=400)
prompt = json_data['prompt']
dynprompt = DynamicPrompt(prompt)
definitions = DynamicNodeDefinitionCache(dynprompt)
updated = definitions.resolve_dynamic_definitions(dynprompt.all_node_ids())
return web.json_response(updated)
@routes.post("/prompt")
async def post_prompt(request):
logging.info("got prompt")