diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 5adb468a..43f623a6 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -125,6 +125,27 @@ class ImageToMask: mask = image[0, :, :, channels.index(channel)] return (mask,) +class ImageColorToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, color): + temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int) + temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2] + mask = torch.where(temp == color, 255, 0).float() + return (mask,) + class SolidMask: @classmethod def INPUT_TYPES(cls): @@ -315,6 +336,7 @@ NODE_CLASS_MAPPINGS = { "ImageCompositeMasked": ImageCompositeMasked, "MaskToImage": MaskToImage, "ImageToMask": ImageToMask, + "ImageColorToMask": ImageColorToMask, "SolidMask": SolidMask, "InvertMask": InvertMask, "CropMask": CropMask,