Reduce repeated calls of INPUT_TYPES in cache (#4922)

This commit is contained in:
JettHu 2024-09-15 13:03:09 +08:00 committed by GitHub
parent ca08597670
commit 5e68a4ce67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 2 deletions

View File

@ -1,11 +1,21 @@
import itertools import itertools
from typing import Sequence, Mapping from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
import nodes import nodes
from comfy_execution.graph_utils import is_link from comfy_execution.graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet: class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {} self.keys = {}
@ -98,7 +108,7 @@ class CacheKeySetInputSignature(CacheKeySet):
class_type = node["class_type"] class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)] signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values(): if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id) signature.append(node_id)
inputs = node["inputs"] inputs = node["inputs"]
for key in sorted(inputs.keys()): for key in sorted(inputs.keys()):