Support batches of masks in mask composite nodes.
This commit is contained in:
parent
ba7dfd60f2
commit
046b4fe0ee
|
@ -1,6 +1,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.ndimage import grey_dilation
|
from scipy.ndimage import grey_dilation
|
||||||
import torch
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
from nodes import MAX_RESOLUTION
|
from nodes import MAX_RESOLUTION
|
||||||
|
|
||||||
|
@ -8,6 +9,8 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||||
if resize_source:
|
if resize_source:
|
||||||
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
|
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
|
||||||
|
|
||||||
|
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
|
||||||
|
|
||||||
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
|
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
|
||||||
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
|
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
|
||||||
|
|
||||||
|
@ -18,8 +21,8 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||||
mask = torch.ones_like(source)
|
mask = torch.ones_like(source)
|
||||||
else:
|
else:
|
||||||
mask = mask.clone()
|
mask = mask.clone()
|
||||||
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
|
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
|
||||||
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
|
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
|
||||||
|
|
||||||
# calculate the bounds of the source that will be overlapping the destination
|
# calculate the bounds of the source that will be overlapping the destination
|
||||||
# this prevents the source trying to overwrite latent pixels that are out of bounds
|
# this prevents the source trying to overwrite latent pixels that are out of bounds
|
||||||
|
@ -122,7 +125,7 @@ class ImageToMask:
|
||||||
|
|
||||||
def image_to_mask(self, image, channel):
|
def image_to_mask(self, image, channel):
|
||||||
channels = ["red", "green", "blue"]
|
channels = ["red", "green", "blue"]
|
||||||
mask = image[0, :, :, channels.index(channel)]
|
mask = image[:, :, :, channels.index(channel)]
|
||||||
return (mask,)
|
return (mask,)
|
||||||
|
|
||||||
class ImageColorToMask:
|
class ImageColorToMask:
|
||||||
|
|
Loading…
Reference in New Issue