From dafbe321d2dfbe2cabd9eb32e88f8088996cd524 Mon Sep 17 00:00:00 2001 From: guill Date: Wed, 21 Aug 2024 20:38:46 -0700 Subject: [PATCH] Fix a bug where cached outputs affected IS_CHANGED (#4535) This change fixes a bug where non-constant values could be passed to the IS_CHANGED function. This would result in workflows taking an extra execution before they acted as if they were cached. The actual change is like 4 characters -- the rest is adding unit tests. --- execution.py | 3 ++- tests/inference/test_execution.py | 19 +++++++++++++ .../testing-pack/specific_tests.py | 27 +++++++++++++++++++ .../testing_nodes/testing-pack/stubs.py | 24 +++++++++++++++++ 4 files changed, 72 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 9278af35..05a662cd 100644 --- a/execution.py +++ b/execution.py @@ -47,7 +47,8 @@ class IsChangedCache: self.is_changed[node_id] = node["is_changed"] return self.is_changed[node_id] - input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED + input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 9df1d7df..7965165f 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -459,3 +459,22 @@ class TestExecution: assert len(images1) == 1, "Should have 1 image" assert len(images2) == 1, "Should have 1 image" + + # This tests that only constant outputs are used in the call to `IS_CHANGED` + def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) + test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) + + output = g.node("PreviewImage", images=test_node.out(0)) + + result = client.run(g) + images = result.get_images(output) + assert len(images) == 1, "Should have 1 image" + assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" + + result = client.run(g) + images = result.get_images(output) + assert len(images) == 1, "Should have 1 image" + assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" + assert not result.did_run(test_node), "The execution should have been cached" diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index b961d1b6..dd810023 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -95,6 +95,31 @@ class TestCustomIsChanged: else: return False +class TestIsChangedWithConstants: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_is_changed" + + CATEGORY = "Testing/Nodes" + + def custom_is_changed(self, image, value): + return (image * value,) + + @classmethod + def IS_CHANGED(cls, image, value): + if image is None: + return value + else: + return image.mean().item() * value + class TestCustomValidation1: @classmethod def INPUT_TYPES(cls): @@ -312,6 +337,7 @@ TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, "TestCustomIsChanged": TestCustomIsChanged, + "TestIsChangedWithConstants": TestIsChangedWithConstants, "TestCustomValidation1": TestCustomValidation1, "TestCustomValidation2": TestCustomValidation2, "TestCustomValidation3": TestCustomValidation3, @@ -325,6 +351,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestLazyMixImages": "Lazy Mix Images", "TestVariadicAverage": "Variadic Average", "TestCustomIsChanged": "Custom IsChanged", + "TestIsChangedWithConstants": "IsChanged With Constants", "TestCustomValidation1": "Custom Validation 1", "TestCustomValidation2": "Custom Validation 2", "TestCustomValidation3": "Custom Validation 3", diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/inference/testing_nodes/testing-pack/stubs.py index 9be6eac9..a1df8752 100644 --- a/tests/inference/testing_nodes/testing-pack/stubs.py +++ b/tests/inference/testing_nodes/testing-pack/stubs.py @@ -28,6 +28,28 @@ class StubImage: elif content == "NOISE": return (torch.rand(batch_size, height, width, 3),) +class StubConstantImage: + def __init__(self): + pass + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_constant_image" + + CATEGORY = "Testing/Stub Nodes" + + def stub_constant_image(self, value, height, width, batch_size): + return (torch.ones(batch_size, height, width, 3) * value,) + class StubMask: def __init__(self): pass @@ -93,12 +115,14 @@ class StubFloat: TEST_STUB_NODE_CLASS_MAPPINGS = { "StubImage": StubImage, + "StubConstantImage": StubConstantImage, "StubMask": StubMask, "StubInt": StubInt, "StubFloat": StubFloat, } TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { "StubImage": "Stub Image", + "StubConstantImage": "Stub Constant Image", "StubMask": "Stub Mask", "StubInt": "Stub Int", "StubFloat": "Stub Float",