JoinImageWithAlpha now works with any mask shape.
This commit is contained in:
parent
9212bea87c
commit
0e763e880f
|
@ -3,6 +3,8 @@ import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
def resize_mask(mask, shape):
|
||||||
|
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
||||||
|
|
||||||
class PorterDuffMode(Enum):
|
class PorterDuffMode(Enum):
|
||||||
ADD = 0
|
ADD = 0
|
||||||
|
@ -178,6 +180,7 @@ class JoinImageWithAlpha:
|
||||||
batch_size = min(len(image), len(alpha))
|
batch_size = min(len(image), len(alpha))
|
||||||
out_images = []
|
out_images = []
|
||||||
|
|
||||||
|
alpha = resize_mask(alpha, image.shape[1:])
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue