diff --git a/nodes.py b/nodes.py index 1f495900..40e7558f 100644 --- a/nodes.py +++ b/nodes.py @@ -919,6 +919,7 @@ class ImagePadForOutpaint: "top": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}), "right": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}), "bottom": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}), + "feathering": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), } } @@ -927,7 +928,7 @@ class ImagePadForOutpaint: CATEGORY = "image" - def expand_image(self, image, left, top, right, bottom): + def expand_image(self, image, left, top, right, bottom, feathering): d1, d2, d3, d4 = image.size() new_image = torch.zeros( @@ -940,10 +941,30 @@ class ImagePadForOutpaint: (d2 + top + bottom, d3 + left + right), dtype=torch.float32, ) - mask[top:top + d2, left:left + d3] = torch.zeros( - (d2, d3), - dtype=torch.float32, - ) + + if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3: + # distances to border + mi, mj = torch.meshgrid( + torch.arange(d2, dtype=torch.float32), + torch.arange(d3, dtype=torch.float32), + indexing='ij', + ) + distances = torch.minimum( + torch.minimum(mi, mj), + torch.minimum(d2 - 1 - mi, d3 - 1 - mj), + ) + # convert distances to square falloff from 1 to 0 + t = (feathering - distances) / feathering + t.clamp_(min=0) + t.square_() + + mask[top:top + d2, left:left + d3] = t + else: + mask[top:top + d2, left:left + d3] = torch.zeros( + (d2, d3), + dtype=torch.float32, + ) + return (new_image, mask)