diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index b0ae2dfa..f39daa00 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -113,19 +113,21 @@ class PorterDuffImageComposite: src_image = source[i] dst_image = destination[i] + assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels + src_alpha = source_alpha[i].unsqueeze(2) dst_alpha = destination_alpha[i].unsqueeze(2) - if dst_alpha.shape != dst_image.shape: - upscale_input = dst_alpha[None,:,:,:].permute(0, 3, 1, 2) + if dst_alpha.shape[:2] != dst_image.shape[:2]: + upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) if src_image.shape != dst_image.shape: - upscale_input = src_image[None,:,:,:].permute(0, 3, 1, 2) + upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0) if src_alpha.shape != dst_alpha.shape: - upscale_input = src_alpha[None,:,:,:].permute(0, 3, 1, 2) + upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center') src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) @@ -177,7 +179,7 @@ class JoinImageWithAlpha: out_images = [] for i in range(batch_size): - out_images.append(torch.cat((image[i], alpha[i].unsqueeze(2)), dim=2)) + out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) result = (torch.stack(out_images),) return result