diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index ab17fc50..60feea0d 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -59,23 +59,41 @@ class LatentCompositeMasked: class MaskToImage: @classmethod - def INPUT_TYPES(cls): + def INPUT_TYPES(s): return { - "required": { - "mask": ("MASK",), - } + "required": { + "mask": ("MASK",), + } } CATEGORY = "mask" RETURN_TYPES = ("IMAGE",) + FUNCTION = "mask_to_image" - FUNCTION = "convert" + def mask_to_image(self, mask): + result = mask[None, :, :, None].expand(-1, -1, -1, 3) + return (result,) - def convert(self, mask): - image = torch.cat([torch.reshape(mask.clone(), [1, mask.shape[0], mask.shape[1], 1,])] * 3, 3) +class ImageToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": (["red", "green", "blue"],), + } + } - return (image,) + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, channel): + channels = ["red", "green", "blue"] + mask = image[0, :, :, channels.index(channel)] + return (mask,) class SolidMask: @classmethod @@ -231,6 +249,7 @@ class FeatherMask: NODE_CLASS_MAPPINGS = { "LatentCompositeMasked": LatentCompositeMasked, "MaskToImage": MaskToImage, + "ImageToMask": ImageToMask, "SolidMask": SolidMask, "InvertMask": InvertMask, "CropMask": CropMask, @@ -238,3 +257,7 @@ NODE_CLASS_MAPPINGS = { "FeatherMask": FeatherMask, } +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageToMask": "Convert Image to Mask", + "MaskToImage": "Convert Mask to Image", +} diff --git a/comfy_extras/nodes_mask_conversion.py b/comfy_extras/nodes_mask_conversion.py deleted file mode 100644 index 04dcbd0d..00000000 --- a/comfy_extras/nodes_mask_conversion.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import torch -import torch.nn.functional as F -from PIL import Image - -import comfy.utils - -class ImageToMask: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue"],), - } - } - - CATEGORY = "image" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, channel): - channels = ["red", "green", "blue"] - mask = image[0, :, :, channels.index(channel)] - return (mask,) - -class MaskToImage: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - } - } - - CATEGORY = "image" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "mask_to_image" - - def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) - return (result,) - -NODE_CLASS_MAPPINGS = { - "ImageToMask": ImageToMask, - "MaskToImage": MaskToImage, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageToMask": "Convert Image to Mask", - "MaskToImage": "Convert Mask to Image", -} diff --git a/nodes.py b/nodes.py index aff03dd4..6468ac6b 100644 --- a/nodes.py +++ b/nodes.py @@ -1193,4 +1193,3 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask_conversion.py"))