diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 181b36ed..48fe5e3d 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -28,6 +28,14 @@ class PorterDuffMode(Enum): def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode): + # convert mask to alpha + src_alpha = 1 - src_alpha + dst_alpha = 1 - dst_alpha + # premultiply alpha + src_image = src_image * src_alpha + dst_image = dst_image * dst_alpha + + # composite ops below assume alpha-premultiplied images if mode == PorterDuffMode.ADD: out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1) out_image = torch.clamp(src_image + dst_image, 0, 1) @@ -35,7 +43,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ out_alpha = torch.zeros_like(dst_alpha) out_image = torch.zeros_like(dst_image) elif mode == PorterDuffMode.DARKEN: - out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image) elif mode == PorterDuffMode.DST: out_alpha = dst_alpha @@ -84,8 +92,13 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image else: - out_alpha = None - out_image = None + return None, None + + # back to non-premultiplied alpha + out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image)) + out_image = torch.clamp(out_image, 0, 1) + # convert alpha to mask + out_alpha = 1 - out_alpha return out_image, out_alpha