Fix a bug where cached outputs affected IS_CHANGED (#4535)
This change fixes a bug where non-constant values could be passed to the IS_CHANGED function. This would result in workflows taking an extra execution before they acted as if they were cached. The actual change is like 4 characters -- the rest is adding unit tests.
This commit is contained in:
parent
5f84ea63e8
commit
dafbe321d2
|
@ -47,7 +47,8 @@ class IsChangedCache:
|
||||||
self.is_changed[node_id] = node["is_changed"]
|
self.is_changed[node_id] = node["is_changed"]
|
||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
|
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
|
|
|
@ -459,3 +459,22 @@ class TestExecution:
|
||||||
assert len(images1) == 1, "Should have 1 image"
|
assert len(images1) == 1, "Should have 1 image"
|
||||||
assert len(images2) == 1, "Should have 1 image"
|
assert len(images2) == 1, "Should have 1 image"
|
||||||
|
|
||||||
|
|
||||||
|
# This tests that only constant outputs are used in the call to `IS_CHANGED`
|
||||||
|
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
|
||||||
|
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
|
||||||
|
|
||||||
|
output = g.node("PreviewImage", images=test_node.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 1, "Should have 1 image"
|
||||||
|
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 1, "Should have 1 image"
|
||||||
|
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||||
|
assert not result.did_run(test_node), "The execution should have been cached"
|
||||||
|
|
|
@ -95,6 +95,31 @@ class TestCustomIsChanged:
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class TestIsChangedWithConstants:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "custom_is_changed"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Nodes"
|
||||||
|
|
||||||
|
def custom_is_changed(self, image, value):
|
||||||
|
return (image * value,)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(cls, image, value):
|
||||||
|
if image is None:
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
return image.mean().item() * value
|
||||||
|
|
||||||
class TestCustomValidation1:
|
class TestCustomValidation1:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
|
@ -312,6 +337,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||||
"TestLazyMixImages": TestLazyMixImages,
|
"TestLazyMixImages": TestLazyMixImages,
|
||||||
"TestVariadicAverage": TestVariadicAverage,
|
"TestVariadicAverage": TestVariadicAverage,
|
||||||
"TestCustomIsChanged": TestCustomIsChanged,
|
"TestCustomIsChanged": TestCustomIsChanged,
|
||||||
|
"TestIsChangedWithConstants": TestIsChangedWithConstants,
|
||||||
"TestCustomValidation1": TestCustomValidation1,
|
"TestCustomValidation1": TestCustomValidation1,
|
||||||
"TestCustomValidation2": TestCustomValidation2,
|
"TestCustomValidation2": TestCustomValidation2,
|
||||||
"TestCustomValidation3": TestCustomValidation3,
|
"TestCustomValidation3": TestCustomValidation3,
|
||||||
|
@ -325,6 +351,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"TestLazyMixImages": "Lazy Mix Images",
|
"TestLazyMixImages": "Lazy Mix Images",
|
||||||
"TestVariadicAverage": "Variadic Average",
|
"TestVariadicAverage": "Variadic Average",
|
||||||
"TestCustomIsChanged": "Custom IsChanged",
|
"TestCustomIsChanged": "Custom IsChanged",
|
||||||
|
"TestIsChangedWithConstants": "IsChanged With Constants",
|
||||||
"TestCustomValidation1": "Custom Validation 1",
|
"TestCustomValidation1": "Custom Validation 1",
|
||||||
"TestCustomValidation2": "Custom Validation 2",
|
"TestCustomValidation2": "Custom Validation 2",
|
||||||
"TestCustomValidation3": "Custom Validation 3",
|
"TestCustomValidation3": "Custom Validation 3",
|
||||||
|
|
|
@ -28,6 +28,28 @@ class StubImage:
|
||||||
elif content == "NOISE":
|
elif content == "NOISE":
|
||||||
return (torch.rand(batch_size, height, width, 3),)
|
return (torch.rand(batch_size, height, width, 3),)
|
||||||
|
|
||||||
|
class StubConstantImage:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||||
|
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "stub_constant_image"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Stub Nodes"
|
||||||
|
|
||||||
|
def stub_constant_image(self, value, height, width, batch_size):
|
||||||
|
return (torch.ones(batch_size, height, width, 3) * value,)
|
||||||
|
|
||||||
class StubMask:
|
class StubMask:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
@ -93,12 +115,14 @@ class StubFloat:
|
||||||
|
|
||||||
TEST_STUB_NODE_CLASS_MAPPINGS = {
|
TEST_STUB_NODE_CLASS_MAPPINGS = {
|
||||||
"StubImage": StubImage,
|
"StubImage": StubImage,
|
||||||
|
"StubConstantImage": StubConstantImage,
|
||||||
"StubMask": StubMask,
|
"StubMask": StubMask,
|
||||||
"StubInt": StubInt,
|
"StubInt": StubInt,
|
||||||
"StubFloat": StubFloat,
|
"StubFloat": StubFloat,
|
||||||
}
|
}
|
||||||
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"StubImage": "Stub Image",
|
"StubImage": "Stub Image",
|
||||||
|
"StubConstantImage": "Stub Constant Image",
|
||||||
"StubMask": "Stub Mask",
|
"StubMask": "Stub Mask",
|
||||||
"StubInt": "Stub Int",
|
"StubInt": "Stub Int",
|
||||||
"StubFloat": "Stub Float",
|
"StubFloat": "Stub Float",
|
||||||
|
|
Loading…
Reference in New Issue