Add a way for nodes to validate their own inputs.
This commit is contained in:
parent
f7a8218814
commit
ccad603b2e
21
execution.py
21
execution.py
|
@ -11,7 +11,6 @@ import torch
|
||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
|
@ -250,14 +249,15 @@ def validate_inputs(prompt, item):
|
||||||
if "max" in info[1] and val > info[1]["max"]:
|
if "max" in info[1] and val > info[1]["max"]:
|
||||||
return (False, "Value bigger than max. {}, {}".format(class_type, x))
|
return (False, "Value bigger than max. {}, {}".format(class_type, x))
|
||||||
|
|
||||||
if isinstance(type_input, list):
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||||
is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]")
|
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||||
if is_annotated_path:
|
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||||
if not folder_paths.exists_annotated_filepath(val):
|
if ret != True:
|
||||||
return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val))
|
return (False, "{}, {}".format(class_type, ret))
|
||||||
|
else:
|
||||||
elif val not in type_input:
|
if isinstance(type_input, list):
|
||||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
if val not in type_input:
|
||||||
|
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||||
return (True, "")
|
return (True, "")
|
||||||
|
|
||||||
def validate_prompt(prompt):
|
def validate_prompt(prompt):
|
||||||
|
@ -279,7 +279,8 @@ def validate_prompt(prompt):
|
||||||
m = validate_inputs(prompt, o)
|
m = validate_inputs(prompt, o)
|
||||||
valid = m[0]
|
valid = m[0]
|
||||||
reason = m[1]
|
reason = m[1]
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(traceback.format_exc())
|
||||||
valid = False
|
valid = False
|
||||||
reason = "Parsing error"
|
reason = "Parsing error"
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,7 @@ def get_directory_by_type(type_name):
|
||||||
|
|
||||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||||
# otherwise use default_path as base_dir
|
# otherwise use default_path as base_dir
|
||||||
def touch_annotated_filepath(name):
|
def annotated_filepath(name):
|
||||||
if name.endswith("[output]"):
|
if name.endswith("[output]"):
|
||||||
base_dir = get_output_directory()
|
base_dir = get_output_directory()
|
||||||
name = name[:-9]
|
name = name[:-9]
|
||||||
|
@ -88,7 +88,7 @@ def touch_annotated_filepath(name):
|
||||||
|
|
||||||
|
|
||||||
def get_annotated_filepath(name, default_dir=None):
|
def get_annotated_filepath(name, default_dir=None):
|
||||||
name, base_dir = touch_annotated_filepath(name)
|
name, base_dir = annotated_filepath(name)
|
||||||
|
|
||||||
if base_dir is None:
|
if base_dir is None:
|
||||||
if default_dir is not None:
|
if default_dir is not None:
|
||||||
|
@ -100,7 +100,7 @@ def get_annotated_filepath(name, default_dir=None):
|
||||||
|
|
||||||
|
|
||||||
def exists_annotated_filepath(name):
|
def exists_annotated_filepath(name):
|
||||||
name, base_dir = touch_annotated_filepath(name)
|
name, base_dir = annotated_filepath(name)
|
||||||
|
|
||||||
if base_dir is None:
|
if base_dir is None:
|
||||||
base_dir = get_input_directory() # fallback path
|
base_dir = get_input_directory() # fallback path
|
||||||
|
|
32
nodes.py
32
nodes.py
|
@ -974,8 +974,7 @@ class LoadImage:
|
||||||
RETURN_TYPES = ("IMAGE", "MASK")
|
RETURN_TYPES = ("IMAGE", "MASK")
|
||||||
FUNCTION = "load_image"
|
FUNCTION = "load_image"
|
||||||
def load_image(self, image):
|
def load_image(self, image):
|
||||||
input_dir = folder_paths.get_input_directory()
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
image_path = folder_paths.get_annotated_filepath(image, input_dir)
|
|
||||||
i = Image.open(image_path)
|
i = Image.open(image_path)
|
||||||
image = i.convert("RGB")
|
image = i.convert("RGB")
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
@ -989,20 +988,27 @@ class LoadImage:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def IS_CHANGED(s, image):
|
def IS_CHANGED(s, image):
|
||||||
input_dir = folder_paths.get_input_directory()
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
image_path = folder_paths.get_annotated_filepath(image, input_dir)
|
|
||||||
m = hashlib.sha256()
|
m = hashlib.sha256()
|
||||||
with open(image_path, 'rb') as f:
|
with open(image_path, 'rb') as f:
|
||||||
m.update(f.read())
|
m.update(f.read())
|
||||||
return m.digest().hex()
|
return m.digest().hex()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(s, image):
|
||||||
|
if not folder_paths.exists_annotated_filepath(image):
|
||||||
|
return "Invalid image file: {}".format(image)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
class LoadImageMask:
|
class LoadImageMask:
|
||||||
|
_color_channels = ["alpha", "red", "green", "blue"]
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
return {"required":
|
return {"required":
|
||||||
{"image": (sorted(os.listdir(input_dir)), ),
|
{"image": (sorted(os.listdir(input_dir)), ),
|
||||||
"channel": (["alpha", "red", "green", "blue"], ),}
|
"channel": (s._color_channels, ),}
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "mask"
|
CATEGORY = "mask"
|
||||||
|
@ -1010,8 +1016,7 @@ class LoadImageMask:
|
||||||
RETURN_TYPES = ("MASK",)
|
RETURN_TYPES = ("MASK",)
|
||||||
FUNCTION = "load_image"
|
FUNCTION = "load_image"
|
||||||
def load_image(self, image, channel):
|
def load_image(self, image, channel):
|
||||||
input_dir = folder_paths.get_input_directory()
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
image_path = folder_paths.get_annotated_filepath(image, input_dir)
|
|
||||||
i = Image.open(image_path)
|
i = Image.open(image_path)
|
||||||
if i.getbands() != ("R", "G", "B", "A"):
|
if i.getbands() != ("R", "G", "B", "A"):
|
||||||
i = i.convert("RGBA")
|
i = i.convert("RGBA")
|
||||||
|
@ -1028,13 +1033,22 @@ class LoadImageMask:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def IS_CHANGED(s, image, channel):
|
def IS_CHANGED(s, image, channel):
|
||||||
input_dir = folder_paths.get_input_directory()
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
image_path = folder_paths.get_annotated_filepath(image, input_dir)
|
|
||||||
m = hashlib.sha256()
|
m = hashlib.sha256()
|
||||||
with open(image_path, 'rb') as f:
|
with open(image_path, 'rb') as f:
|
||||||
m.update(f.read())
|
m.update(f.read())
|
||||||
return m.digest().hex()
|
return m.digest().hex()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(s, image, channel):
|
||||||
|
if not folder_paths.exists_annotated_filepath(image):
|
||||||
|
return "Invalid image file: {}".format(image)
|
||||||
|
|
||||||
|
if channel not in s._color_channels:
|
||||||
|
return "Invalid color channel: {}".format(channel)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
class ImageScale:
|
class ImageScale:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||||
crop_methods = ["disabled", "center"]
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
|
@ -172,7 +172,7 @@ export class ComfyApp {
|
||||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||||
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
|
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
|
||||||
if (prop) {
|
if (prop) {
|
||||||
prop.value = value;
|
prop.callback(value);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue