From 6ab1e6fd4a2f7cc5945310f0ecfc11617aa9a2cb Mon Sep 17 00:00:00 2001 From: guill Date: Sat, 24 Aug 2024 12:34:58 -0700 Subject: [PATCH] [Bug #4529] Fix graph partial validation failure (#4588) Currently, if a graph partially fails validation (i.e. some outputs are valid while others have links from missing nodes), the execution loop could get an exception resulting in server lockup. This isn't actually possible to reproduce via the default UI, but is a potential issue for people using the API to construct invalid graphs. --- comfy_execution/caching.py | 9 +++++++++ tests/inference/test_execution.py | 23 +++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 6664a342..e67914a3 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -56,6 +56,8 @@ class CacheKeySetID(CacheKeySet): for node_id in node_ids: if node_id in self.keys: continue + if not self.dynprompt.has_node(node_id): + continue node = self.dynprompt.get_node(node_id) self.keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"]) @@ -74,6 +76,8 @@ class CacheKeySetInputSignature(CacheKeySet): for node_id in node_ids: if node_id in self.keys: continue + if not self.dynprompt.has_node(node_id): + continue node = self.dynprompt.get_node(node_id) self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) self.subcache_keys[node_id] = (node_id, node["class_type"]) @@ -87,6 +91,9 @@ class CacheKeySetInputSignature(CacheKeySet): return to_hashable(signature) def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + if not dynprompt.has_node(node_id): + # This node doesn't exist -- we can't cache it. + return [float("NaN")] node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -112,6 +119,8 @@ class CacheKeySetInputSignature(CacheKeySet): return ancestors, order_mapping def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + if not dynprompt.has_node(node_id): + return inputs = dynprompt.get_node(node_id)["inputs"] input_keys = sorted(inputs.keys()) for key in input_keys: diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 7965165f..ffc0c482 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -357,6 +357,25 @@ class TestExecution: assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" + def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0)) + # We have multiple outputs. The first is invalid, but the second is valid + g.node("SaveImage", images=mix1.out(0)) + g.node("SaveImage", images=mix2.out(0)) + g.remove_node("removeme") + + client.run(g) + + # Add back in the missing node to make sure the error doesn't break the server + input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) + client.run(g) + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug @@ -450,8 +469,8 @@ class TestExecution: g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) - output1 = g.node("PreviewImage", images=input1.out(0)) - output2 = g.node("PreviewImage", images=input1.out(0)) + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input1.out(0)) result = client.run(g) images1 = result.get_images(output1)