From 214ca7197ef753bce3b40f642c6775d919568c2f Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Sun, 24 Sep 2023 00:12:55 +0200 Subject: [PATCH] Corrected joining images with alpha (for RGBA input), and checking scaling conditions --- comfy_extras/nodes_compositing.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index b0ae2dfa..f39daa00 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -113,19 +113,21 @@ class PorterDuffImageComposite: src_image = source[i] dst_image = destination[i] + assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels + src_alpha = source_alpha[i].unsqueeze(2) dst_alpha = destination_alpha[i].unsqueeze(2) - if dst_alpha.shape != dst_image.shape: - upscale_input = dst_alpha[None,:,:,:].permute(0, 3, 1, 2) + if dst_alpha.shape[:2] != dst_image.shape[:2]: + upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) if src_image.shape != dst_image.shape: - upscale_input = src_image[None,:,:,:].permute(0, 3, 1, 2) + upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0) if src_alpha.shape != dst_alpha.shape: - upscale_input = src_alpha[None,:,:,:].permute(0, 3, 1, 2) + upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center') src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) @@ -177,7 +179,7 @@ class JoinImageWithAlpha: out_images = [] for i in range(batch_size): - out_images.append(torch.cat((image[i], alpha[i].unsqueeze(2)), dim=2)) + out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) result = (torch.stack(out_images),) return result