From ed0c0d1c26b8a935a8625e5b198f4f27b173e264 Mon Sep 17 00:00:00 2001 From: Denys Smirnov Date: Mon, 6 May 2024 12:33:16 +0300 Subject: [PATCH] Allow joining a batch of images with a single mask Previously, JoinImageWithAlpha required a batch of images to match a batch of masks. But for some use cases it's easier to provide a batch of images and a single mask. This change automatically repeats the mask for all images in a batch. In the same spirit, PorterDuffImageComposite will now allow a single mask for a batch of images (for both src and dst). But also, PorterDuffImageComposite will apply the same logic to src and dst: if src contains one image, and dst is a batch it will repeat src to match dst (or the opposite). --- comfy_extras/nodes_compositing.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 181b36ed..91b36022 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -107,10 +107,24 @@ class PorterDuffImageComposite: CATEGORY = "mask/compositing" def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): - batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) + batch_size = min(len(source), len(destination)) + if batch_size == 1: + if len(source) != 1: + batch_size = len(source) + elif len(destination) != 1: + batch_size = len(destination) out_images = [] out_alphas = [] + if batch_size != 1: + if len(source) == 1: + source = source.repeat(batch_size, 1, 1, 1) + if len(destination) == 1: + destination = destination.repeat(batch_size, 1, 1, 1) + if len(source_alpha) == 1: + source_alpha = source_alpha.repeat(batch_size, 1, 1) + if len(destination_alpha) == 1: + destination_alpha = destination_alpha.repeat(batch_size, 1, 1) for i in range(batch_size): src_image = source[i] dst_image = destination[i] @@ -180,6 +194,8 @@ class JoinImageWithAlpha: batch_size = min(len(image), len(alpha)) out_images = [] + if len(alpha) == 1 and batch_size != 1: + alpha = alpha.repeat(batch_size, 1, 1, 1) alpha = 1.0 - resize_mask(alpha, image.shape[1:]) for i in range(batch_size): out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))