From 20447e9ec92b7e7e3544a6fd2932c31c90333991 Mon Sep 17 00:00:00 2001
From: Denys Smirnov
Date: Tue, 4 Jun 2024 23:37:11 +0300
Subject: [PATCH] Fix alpha in PorterDuffImageComposite. (#3411)

There were two bugs in PorterDuffImageComposite.

The first is that it uses the mask input directly as alpha, missing the
conversion (`1-a`). The fix is similar to c16f5744.

The second is that all of the color composition formulas assume
alpha-premultiplied values, while the input is not premultiplied.

This change fixes both issues.
---
 comfy_extras/nodes_compositing.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py
index 181b36ed..48fe5e3d 100644
--- a/comfy_extras/nodes_compositing.py
+++ b/comfy_extras/nodes_compositing.py
@@ -28,6 +28,14 @@ class PorterDuffMode(Enum):
 
 
 def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
+    # convert mask to alpha
+    src_alpha = 1 - src_alpha
+    dst_alpha = 1 - dst_alpha
+    # premultiply alpha
+    src_image = src_image * src_alpha
+    dst_image = dst_image * dst_alpha
+
+    # composite ops below assume alpha-premultiplied images
     if mode == PorterDuffMode.ADD:
         out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
         out_image = torch.clamp(src_image + dst_image, 0, 1)
@@ -35,7 +43,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
         out_alpha = torch.zeros_like(dst_alpha)
         out_image = torch.zeros_like(dst_image)
     elif mode == PorterDuffMode.DARKEN:
-        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha 
+        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
         out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
     elif mode == PorterDuffMode.DST:
         out_alpha = dst_alpha
@@ -84,8 +92,13 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
         out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
         out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
     else:
-        out_alpha = None
-        out_image = None
+        return None, None
+
+    # back to non-premultiplied alpha
+    out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
+    out_image = torch.clamp(out_image, 0, 1)
+    # convert alpha to mask
+    out_alpha = 1 - out_alpha
 
     return out_image, out_alpha
 
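
For reviewers, a minimal standalone sketch of the round-trip this patch introduces, using
"source over" as the example mode. It assumes 0..1 float tensors and the mask convention
implied by the `1 - mask` conversion above (mask 0 = opaque, 1 = transparent);
`blend_src_over` is a hypothetical helper written for illustration, not code from this patch.

    import torch

    def blend_src_over(src_image, src_mask, dst_image, dst_mask):
        # Convert masks to alpha (mask 0 -> alpha 1, i.e. opaque).
        src_alpha = 1 - src_mask
        dst_alpha = 1 - dst_mask
        # Premultiply color by alpha; the Porter-Duff formulas assume premultiplied values.
        src_p = src_image * src_alpha
        dst_p = dst_image * dst_alpha
        # Porter-Duff "source over" on premultiplied color.
        out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
        out_p = src_p + (1 - src_alpha) * dst_p
        # Un-premultiply, guarding against division by near-zero alpha.
        out_image = torch.where(out_alpha > 1e-5, out_p / out_alpha, torch.zeros_like(out_p))
        # Return straight (non-premultiplied) color plus a mask (1 - alpha).
        return torch.clamp(out_image, 0, 1), 1 - out_alpha

    src = torch.rand(1, 8, 8, 3)
    dst = torch.rand(1, 8, 8, 3)
    src_mask = torch.full((1, 8, 8, 1), 0.5)  # half-transparent source
    dst_mask = torch.zeros((1, 8, 8, 1))      # fully opaque destination
    image, mask = blend_src_over(src, src_mask, dst, dst_mask)

With a half-transparent source over an opaque destination this reduces to a 50/50 mix, and
dividing by `out_alpha` at the end (guarded against near-zero alpha, as the patch does)
brings the result back to straight, non-premultiplied color.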