From ccad603b2e6862a4a719bc34dc6bd32e65a539ad Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 23 Apr 2023 16:03:26 -0400 Subject: [PATCH] Add a way for nodes to validate their own inputs. --- execution.py | 21 +++++++++++---------- folder_paths.py | 6 +++--- nodes.py | 32 +++++++++++++++++++++++--------- web/scripts/app.js | 2 +- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/execution.py b/execution.py index b062deeb..115efcbd 100644 --- a/execution.py +++ b/execution.py @@ -11,7 +11,6 @@ import torch import nodes import comfy.model_management -import folder_paths def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() @@ -250,14 +249,15 @@ def validate_inputs(prompt, item): if "max" in info[1] and val > info[1]["max"]: return (False, "Value bigger than max. {}, {}".format(class_type, x)) - if isinstance(type_input, list): - is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]") - if is_annotated_path: - if not folder_paths.exists_annotated_filepath(val): - return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val)) - - elif val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + if hasattr(obj_class, "VALIDATE_INPUTS"): + input_data_all = get_input_data(inputs, obj_class, unique_id) + ret = obj_class.VALIDATE_INPUTS(**input_data_all) + if ret != True: + return (False, "{}, {}".format(class_type, ret)) + else: + if isinstance(type_input, list): + if val not in type_input: + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (True, "") def validate_prompt(prompt): @@ -279,7 +279,8 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o) valid = m[0] reason = m[1] - except: + except Exception as e: + print(traceback.format_exc()) valid = False reason = "Parsing error" diff --git a/folder_paths.py b/folder_paths.py index 99a01669..e5b89492 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -71,7 +71,7 @@ def get_directory_by_type(type_name): # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # otherwise use default_path as base_dir -def touch_annotated_filepath(name): +def annotated_filepath(name): if name.endswith("[output]"): base_dir = get_output_directory() name = name[:-9] @@ -88,7 +88,7 @@ def touch_annotated_filepath(name): def get_annotated_filepath(name, default_dir=None): - name, base_dir = touch_annotated_filepath(name) + name, base_dir = annotated_filepath(name) if base_dir is None: if default_dir is not None: @@ -100,7 +100,7 @@ def get_annotated_filepath(name, default_dir=None): def exists_annotated_filepath(name): - name, base_dir = touch_annotated_filepath(name) + name, base_dir = annotated_filepath(name) if base_dir is None: base_dir = get_input_directory() # fallback path diff --git a/nodes.py b/nodes.py index b8b6280d..d1133d1d 100644 --- a/nodes.py +++ b/nodes.py @@ -974,8 +974,7 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -989,20 +988,27 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + return True + class LoadImageMask: + _color_channels = ["alpha", "red", "green", "blue"] @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() return {"required": {"image": (sorted(os.listdir(input_dir)), ), - "channel": (["alpha", "red", "green", "blue"], ),} + "channel": (s._color_channels, ),} } CATEGORY = "mask" @@ -1010,8 +1016,7 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") @@ -1028,13 +1033,22 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image, channel): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + if channel not in s._color_channels: + return "Invalid color channel: {}".format(channel) + + return True + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] diff --git a/web/scripts/app.js b/web/scripts/app.js index b3e88d46..a161bf40 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -172,7 +172,7 @@ export class ComfyApp { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); if (prop) { - prop.value = value; + prop.callback(value); } }); }