Allow joining a batch of images with a single mask
Previously, JoinImageWithAlpha required a batch of images to match a batch of masks. But for some use cases it's easier to provide a batch of images and a single mask. This change automatically repeats the mask for all images in a batch. In the same spirit, PorterDuffImageComposite will now allow a single mask for a batch of images (for both src and dst). But also, PorterDuffImageComposite will apply the same logic to src and dst: if src contains one image, and dst is a batch it will repeat src to match dst (or the opposite).
This commit is contained in:
parent
565eb6d176
commit
ed0c0d1c26
|
@ -107,10 +107,24 @@ class PorterDuffImageComposite:
|
|||
CATEGORY = "mask/compositing"
|
||||
|
||||
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
|
||||
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
||||
batch_size = min(len(source), len(destination))
|
||||
if batch_size == 1:
|
||||
if len(source) != 1:
|
||||
batch_size = len(source)
|
||||
elif len(destination) != 1:
|
||||
batch_size = len(destination)
|
||||
out_images = []
|
||||
out_alphas = []
|
||||
|
||||
if batch_size != 1:
|
||||
if len(source) == 1:
|
||||
source = source.repeat(batch_size, 1, 1, 1)
|
||||
if len(destination) == 1:
|
||||
destination = destination.repeat(batch_size, 1, 1, 1)
|
||||
if len(source_alpha) == 1:
|
||||
source_alpha = source_alpha.repeat(batch_size, 1, 1)
|
||||
if len(destination_alpha) == 1:
|
||||
destination_alpha = destination_alpha.repeat(batch_size, 1, 1)
|
||||
for i in range(batch_size):
|
||||
src_image = source[i]
|
||||
dst_image = destination[i]
|
||||
|
@ -180,6 +194,8 @@ class JoinImageWithAlpha:
|
|||
batch_size = min(len(image), len(alpha))
|
||||
out_images = []
|
||||
|
||||
if len(alpha) == 1 and batch_size != 1:
|
||||
alpha = alpha.repeat(batch_size, 1, 1, 1)
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
|
Loading…
Reference in New Issue