Added support for Porter-Duff image compositing
This commit is contained in:
parent
9bfec2bdbf
commit
d06cd2805d
|
@ -0,0 +1,239 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class PorterDuffMode(Enum):
|
||||||
|
ADD = 0
|
||||||
|
CLEAR = 1
|
||||||
|
DARKEN = 2
|
||||||
|
DST = 3
|
||||||
|
DST_ATOP = 4
|
||||||
|
DST_IN = 5
|
||||||
|
DST_OUT = 6
|
||||||
|
DST_OVER = 7
|
||||||
|
LIGHTEN = 8
|
||||||
|
MULTIPLY = 9
|
||||||
|
OVERLAY = 10
|
||||||
|
SCREEN = 11
|
||||||
|
SRC = 12
|
||||||
|
SRC_ATOP = 13
|
||||||
|
SRC_IN = 14
|
||||||
|
SRC_OUT = 15
|
||||||
|
SRC_OVER = 16
|
||||||
|
XOR = 17
|
||||||
|
|
||||||
|
|
||||||
|
def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
|
||||||
|
if mode == PorterDuffMode.ADD:
|
||||||
|
out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
|
||||||
|
out_image = torch.clamp(src_image + dst_image, 0, 1)
|
||||||
|
elif mode == PorterDuffMode.CLEAR:
|
||||||
|
out_alpha = torch.zeros_like(dst_alpha)
|
||||||
|
out_image = torch.zeros_like(dst_image)
|
||||||
|
elif mode == PorterDuffMode.DARKEN:
|
||||||
|
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||||
|
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
|
||||||
|
elif mode == PorterDuffMode.DST:
|
||||||
|
out_alpha = dst_alpha
|
||||||
|
out_image = dst_image
|
||||||
|
elif mode == PorterDuffMode.DST_ATOP:
|
||||||
|
out_alpha = src_alpha
|
||||||
|
out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
|
||||||
|
elif mode == PorterDuffMode.DST_IN:
|
||||||
|
out_alpha = src_alpha * dst_alpha
|
||||||
|
out_image = dst_image * src_alpha
|
||||||
|
elif mode == PorterDuffMode.DST_OUT:
|
||||||
|
out_alpha = (1 - src_alpha) * dst_alpha
|
||||||
|
out_image = (1 - src_alpha) * dst_image
|
||||||
|
elif mode == PorterDuffMode.DST_OVER:
|
||||||
|
out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
|
||||||
|
out_image = dst_image + (1 - dst_alpha) * src_image
|
||||||
|
elif mode == PorterDuffMode.LIGHTEN:
|
||||||
|
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||||
|
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
|
||||||
|
elif mode == PorterDuffMode.MULTIPLY:
|
||||||
|
out_alpha = src_alpha * dst_alpha
|
||||||
|
out_image = src_image * dst_image
|
||||||
|
elif mode == PorterDuffMode.OVERLAY:
|
||||||
|
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||||
|
out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
|
||||||
|
src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
|
||||||
|
elif mode == PorterDuffMode.SCREEN:
|
||||||
|
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||||
|
out_image = src_image + dst_image - src_image * dst_image
|
||||||
|
elif mode == PorterDuffMode.SRC:
|
||||||
|
out_alpha = src_alpha
|
||||||
|
out_image = src_image
|
||||||
|
elif mode == PorterDuffMode.SRC_ATOP:
|
||||||
|
out_alpha = dst_alpha
|
||||||
|
out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
|
||||||
|
elif mode == PorterDuffMode.SRC_IN:
|
||||||
|
out_alpha = src_alpha * dst_alpha
|
||||||
|
out_image = src_image * dst_alpha
|
||||||
|
elif mode == PorterDuffMode.SRC_OUT:
|
||||||
|
out_alpha = (1 - dst_alpha) * src_alpha
|
||||||
|
out_image = (1 - dst_alpha) * src_image
|
||||||
|
elif mode == PorterDuffMode.SRC_OVER:
|
||||||
|
out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
|
||||||
|
out_image = src_image + (1 - src_alpha) * dst_image
|
||||||
|
elif mode == PorterDuffMode.XOR:
|
||||||
|
out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
|
||||||
|
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
|
||||||
|
else:
|
||||||
|
out_alpha = None
|
||||||
|
out_image = None
|
||||||
|
return out_image, out_alpha
|
||||||
|
|
||||||
|
|
||||||
|
class PorterDuffImageComposite:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"source": ("IMAGE",),
|
||||||
|
"source_alpha": ("ALPHA",),
|
||||||
|
"destination": ("IMAGE",),
|
||||||
|
"destination_alpha": ("ALPHA",),
|
||||||
|
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE", "ALPHA")
|
||||||
|
FUNCTION = "composite"
|
||||||
|
CATEGORY = "compositing"
|
||||||
|
|
||||||
|
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
|
||||||
|
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
||||||
|
out_images = []
|
||||||
|
out_alphas = []
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
src_image = source[i]
|
||||||
|
dst_image = destination[i]
|
||||||
|
|
||||||
|
src_alpha = source_alpha[i].unsqueeze(2)
|
||||||
|
dst_alpha = destination_alpha[i].unsqueeze(2)
|
||||||
|
|
||||||
|
if dst_alpha.shape != dst_image.shape:
|
||||||
|
upscale_input = dst_alpha[None,:,:,:].permute(0, 3, 1, 2)
|
||||||
|
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
|
||||||
|
dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||||
|
if src_image.shape != dst_image.shape:
|
||||||
|
upscale_input = src_image[None,:,:,:].permute(0, 3, 1, 2)
|
||||||
|
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
|
||||||
|
src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||||
|
if src_alpha.shape != dst_alpha.shape:
|
||||||
|
upscale_input = src_alpha[None,:,:,:].permute(0, 3, 1, 2)
|
||||||
|
upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
|
||||||
|
src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||||
|
|
||||||
|
out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
|
||||||
|
|
||||||
|
out_images.append(out_image)
|
||||||
|
out_alphas.append(out_alpha.squeeze(2))
|
||||||
|
|
||||||
|
result = (torch.stack(out_images), torch.stack(out_alphas))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class SplitImageWithAlpha:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "compositing"
|
||||||
|
RETURN_TYPES = ("IMAGE", "ALPHA")
|
||||||
|
FUNCTION = "split_image_with_alpha"
|
||||||
|
|
||||||
|
def split_image_with_alpha(self, image: torch.Tensor):
|
||||||
|
out_images = [i[:,:,:3] for i in image]
|
||||||
|
out_alphas = [i[:,:,3] for i in image]
|
||||||
|
result = (torch.stack(out_images), torch.stack(out_alphas))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class JoinImageWithAlpha:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"alpha": ("ALPHA",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "compositing"
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "join_image_with_alpha"
|
||||||
|
|
||||||
|
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
|
||||||
|
batch_size = min(len(image), len(alpha))
|
||||||
|
out_images = []
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
out_images.append(torch.cat((image[i], alpha[i].unsqueeze(2)), dim=2))
|
||||||
|
|
||||||
|
result = (torch.stack(out_images),)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertAlphaToImage:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"alpha": ("ALPHA",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "compositing"
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "alpha_to_image"
|
||||||
|
|
||||||
|
def alpha_to_image(self, alpha):
|
||||||
|
result = alpha.reshape((-1, 1, alpha.shape[-2], alpha.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
|
return (result,)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertImageToAlpha:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"channel": (["red", "green", "blue", "alpha"],),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "compositing"
|
||||||
|
RETURN_TYPES = ("ALPHA",)
|
||||||
|
FUNCTION = "image_to_alpha"
|
||||||
|
|
||||||
|
def image_to_alpha(self, image, channel):
|
||||||
|
channels = ["red", "green", "blue", "alpha"]
|
||||||
|
alpha = image[0, :, :, channels.index(channel)]
|
||||||
|
return (alpha,)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"PorterDuffImageComposite": PorterDuffImageComposite,
|
||||||
|
"SplitImageWithAlpha": SplitImageWithAlpha,
|
||||||
|
"JoinImageWithAlpha": JoinImageWithAlpha,
|
||||||
|
"ConvertAlphaToImage": ConvertAlphaToImage,
|
||||||
|
"ConvertImageToAlpha": ConvertImageToAlpha,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"PorterDuffImageComposite": "Porter-Duff Image Composite",
|
||||||
|
"SplitImageWithAlpha": "Split Image with Alpha",
|
||||||
|
"JoinImageWithAlpha": "Join Image with Alpha",
|
||||||
|
"ConvertAlphaToImage": "Convert Alpha to Image",
|
||||||
|
"ConvertImageToAlpha": "Convert Image to Alpha",
|
||||||
|
}
|
28
nodes.py
28
nodes.py
|
@ -1372,6 +1372,31 @@ class LoadImage:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
class LoadImageWithAlpha(LoadImage):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
|
return {"required":
|
||||||
|
{"image": (sorted(files), {"image_upload": True})},
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "compositing"
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE", "ALPHA")
|
||||||
|
|
||||||
|
FUNCTION = "load_image"
|
||||||
|
def load_image(self, image):
|
||||||
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
|
i = Image.open(image_path)
|
||||||
|
i = ImageOps.exif_transpose(i)
|
||||||
|
image = i.convert("RGBA")
|
||||||
|
alpha = np.array(image.getchannel("A")).astype(np.float32) / 255.0
|
||||||
|
alpha = torch.from_numpy(alpha)[None,]
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = torch.from_numpy(image)[None,]
|
||||||
|
return (image, alpha)
|
||||||
|
|
||||||
class LoadImageMask:
|
class LoadImageMask:
|
||||||
_color_channels = ["alpha", "red", "green", "blue"]
|
_color_channels = ["alpha", "red", "green", "blue"]
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1606,6 +1631,7 @@ NODE_CLASS_MAPPINGS = {
|
||||||
"SaveImage": SaveImage,
|
"SaveImage": SaveImage,
|
||||||
"PreviewImage": PreviewImage,
|
"PreviewImage": PreviewImage,
|
||||||
"LoadImage": LoadImage,
|
"LoadImage": LoadImage,
|
||||||
|
"LoadImageWithAlpha": LoadImageWithAlpha,
|
||||||
"LoadImageMask": LoadImageMask,
|
"LoadImageMask": LoadImageMask,
|
||||||
"ImageScale": ImageScale,
|
"ImageScale": ImageScale,
|
||||||
"ImageScaleBy": ImageScaleBy,
|
"ImageScaleBy": ImageScaleBy,
|
||||||
|
@ -1702,6 +1728,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"SaveImage": "Save Image",
|
"SaveImage": "Save Image",
|
||||||
"PreviewImage": "Preview Image",
|
"PreviewImage": "Preview Image",
|
||||||
"LoadImage": "Load Image",
|
"LoadImage": "Load Image",
|
||||||
|
"LoadImageWithAlpha": "Load Image with Alpha",
|
||||||
"LoadImageMask": "Load Image (as Mask)",
|
"LoadImageMask": "Load Image (as Mask)",
|
||||||
"ImageScale": "Upscale Image",
|
"ImageScale": "Upscale Image",
|
||||||
"ImageScaleBy": "Upscale Image By",
|
"ImageScaleBy": "Upscale Image By",
|
||||||
|
@ -1788,6 +1815,7 @@ def init_custom_nodes():
|
||||||
"nodes_upscale_model.py",
|
"nodes_upscale_model.py",
|
||||||
"nodes_post_processing.py",
|
"nodes_post_processing.py",
|
||||||
"nodes_mask.py",
|
"nodes_mask.py",
|
||||||
|
"nodes_compositing.py",
|
||||||
"nodes_rebatch.py",
|
"nodes_rebatch.py",
|
||||||
"nodes_model_merging.py",
|
"nodes_model_merging.py",
|
||||||
"nodes_tomesd.py",
|
"nodes_tomesd.py",
|
||||||
|
|
Loading…
Reference in New Issue