Corrected joining images with alpha (for RGBA input), and checking scaling conditions
This commit is contained in:
parent
585fb0475b
commit
214ca7197e
|
@ -113,19 +113,21 @@ class PorterDuffImageComposite:
|
||||||
src_image = source[i]
|
src_image = source[i]
|
||||||
dst_image = destination[i]
|
dst_image = destination[i]
|
||||||
|
|
||||||
|
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
|
||||||
|
|
||||||
src_alpha = source_alpha[i].unsqueeze(2)
|
src_alpha = source_alpha[i].unsqueeze(2)
|
||||||
dst_alpha = destination_alpha[i].unsqueeze(2)
|
dst_alpha = destination_alpha[i].unsqueeze(2)
|
||||||
|
|
||||||
if dst_alpha.shape != dst_image.shape:
|
if dst_alpha.shape[:2] != dst_image.shape[:2]:
|
||||||
upscale_input = dst_alpha[None,:,:,:].permute(0, 3, 1, 2)
|
upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
|
||||||
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
|
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
|
||||||
dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||||
if src_image.shape != dst_image.shape:
|
if src_image.shape != dst_image.shape:
|
||||||
upscale_input = src_image[None,:,:,:].permute(0, 3, 1, 2)
|
upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
|
||||||
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
|
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
|
||||||
src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||||
if src_alpha.shape != dst_alpha.shape:
|
if src_alpha.shape != dst_alpha.shape:
|
||||||
upscale_input = src_alpha[None,:,:,:].permute(0, 3, 1, 2)
|
upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
|
||||||
upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
|
upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
|
||||||
src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||||
|
|
||||||
|
@ -177,7 +179,7 @@ class JoinImageWithAlpha:
|
||||||
out_images = []
|
out_images = []
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
out_images.append(torch.cat((image[i], alpha[i].unsqueeze(2)), dim=2))
|
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||||
|
|
||||||
result = (torch.stack(out_images),)
|
result = (torch.stack(out_images),)
|
||||||
return result
|
return result
|
||||||
|
|
Loading…
Reference in New Issue