Default to sampling entire image
By default, when applying a mask to a condition, the entire image will still be used for sampling. The new "set_area_to_bounds" option on the node will allow the user to automatically limit conditioning to the bounds of the mask. I've also removed the dependency on torchvision for calculating bounding boxes. I've taken the opportunity to fix some frustrating details in the other version:

1. An all-0 mask will no longer cause an error.
2. Indices are returned as integers instead of floats so they can be used to index into tensors.
This commit is contained in:
parent e214c917ae
commit af02393c2a
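As a point of reference for the two fixes listed above, this is roughly the torchvision behavior being replaced: masks_to_boxes returns float coordinates and errors out on a mask with no nonzero pixels. The mask shape and values below are made up for illustration.

import torch
from torchvision.ops import masks_to_boxes  # the dependency this commit removes

mask = torch.zeros(1, 64, 64)
mask[0, 10:20, 30:40] = 1.0
print(masks_to_boxes(mask))  # tensor([[30., 10., 39., 19.]]) -- floats, which need casting before tensor indexing
# masks_to_boxes(torch.zeros(1, 64, 64)) raises a RuntimeError, since an all-zero
# mask has no nonzero pixels to reduce over -- the edge case fix (1) refers to.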
@@ -6,7 +6,6 @@ import contextlib
 from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
-from torchvision.ops import masks_to_boxes
 
 #The main sampling function shared by all the samplers
 #Returns predicted noise
@@ -31,8 +30,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             assert(mask.shape[1] == x_in.shape[2])
             assert(mask.shape[2] == x_in.shape[3])
             mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-            if mask.shape[0] != input_x.shape[0]:
-                mask = mask.repeat(input_x.shape[0], 1, 1)
+            mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
         else:
             mask = torch.ones_like(input_x)
         mult = mask * strength
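The single repeat() call added in this hunk replaces the old conditional repeat and expands the mask to the full latent shape in one step. A minimal sketch of the broadcasting, with illustrative shapes that are not taken from the commit:

import torch

input_x = torch.randn(2, 4, 32, 32)   # latent batch: (batch, channels, h, w)
mask = torch.rand(1, 32, 32)          # per-condition mask: (batch_or_1, h, w)
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
print(mask.shape)                     # torch.Size([2, 4, 32, 32]) -- multiplies cleanly against input_x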
@@ -315,6 +313,29 @@ def blank_inpaint_image_like(latent_image):
     blank_image[:,3] *= 0.1380
     return blank_image
 
+def get_mask_aabb(masks):
+    if masks.numel() == 0:
+        return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
+
+    b = masks.shape[0]
+
+    bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
+    is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
+    for i in range(b):
+        mask = masks[i]
+        if mask.numel() == 0:
+            continue
+        if torch.max(mask != 0) == False:
+            is_empty[i] = True
+            continue
+        y, x = torch.where(mask)
+        bounding_boxes[i, 0] = torch.min(x)
+        bounding_boxes[i, 1] = torch.min(y)
+        bounding_boxes[i, 2] = torch.max(x)
+        bounding_boxes[i, 3] = torch.max(y)
+
+    return bounding_boxes, is_empty
+
 def resolve_cond_masks(conditions, h, w, device):
     # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
     # While we're doing this, we can also resolve the mask device and scaling for performance reasons
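A quick sanity check of the new get_mask_aabb helper added above, assuming it is in scope (imported from the sampler module). The example tensors are made up, but the expected outputs follow directly from the code:

import torch

masks = torch.zeros(2, 64, 64)
masks[0, 10:20, 30:40] = 1.0           # one real mask, one all-zero mask
boxes, is_empty = get_mask_aabb(masks)
print(boxes[0])   # tensor([30, 10, 39, 19], dtype=torch.int32) -- (x_min, y_min, x_max, y_max) as integers
print(is_empty)   # tensor([False,  True]) -- the all-zero mask is flagged instead of raising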
@@ -329,13 +350,14 @@ def resolve_cond_masks(conditions, h, w, device):
         if mask.shape[2] != h or mask.shape[3] != w:
             mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
 
         if 'area' not in modified:
             if modified.get("set_area_to_bounds", False):
                 bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
-                if torch.max(bounds) == 0:
-                    # Handle the edge-case of an all black mask (where masks_to_boxes would error)
-                    area = (0, 0, 0, 0)
+                boxes, is_empty = get_mask_aabb(bounds)
+                if is_empty[0]:
+                    # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
+                    modified['area'] = (8, 8, 0, 0)
                 else:
-                    box = masks_to_boxes(bounds)[0].type(torch.int)
+                    box = boxes[0]
                     H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
                     # Make sure the height and width are divisible by 8
                     if X % 8 != 0:
@@ -350,8 +372,8 @@ def resolve_cond_masks(conditions, h, w, device):
                         H = H + (8 - (H % 8))
                     if W % 8 != 0:
                         W = W + (8 - (W % 8))
-                    area = (int(H), int(W), int(Y), (X))
-                    modified['area'] = area
+                    area = (int(H), int(W), int(Y), int(X))
+                    modified['area'] = area
 
         modified['mask'] = mask
         conditions[i] = [c[0], modified]
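To make the rounding above concrete, here is a worked example using the bounding box from the earlier sketch. The X/Y adjustments happen in the lines elided between these two hunks, so only the height/width rounding is shown:

# bounding box (x_min, y_min, x_max, y_max) = (30, 10, 39, 19)
H, W, Y, X = (19 - 10 + 1, 39 - 30 + 1, 10, 30)   # H = W = 10
if H % 8 != 0:
    H = H + (8 - (H % 8))                          # H -> 16
if W % 8 != 0:
    W = W + (8 - (W % 8))                          # W -> 16
# After X and Y are shifted by the elided lines, the area is stored as
# area = (int(H), int(W), int(Y), int(X)) -- all plain Python ints after this commit.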
nodes.py
@@ -90,6 +90,7 @@ class ConditioningSetMask:
     def INPUT_TYPES(s):
         return {"required": {"conditioning": ("CONDITIONING", ),
                              "mask": ("MASK", ),
+                             "set_area_to_bounds": ([False, True],),
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             }}
     RETURN_TYPES = ("CONDITIONING",)
@@ -97,7 +98,7 @@ class ConditioningSetMask:
 
     CATEGORY = "conditioning"
 
-    def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0):
+    def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0):
         c = []
         if len(mask.shape) < 3:
             mask = mask.unsqueeze(0)
@@ -105,6 +106,7 @@ class ConditioningSetMask:
             n = [t[0], t[1].copy()]
             _, h, w = mask.shape
             n[1]['mask'] = mask
+            n[1]['set_area_to_bounds'] = set_area_to_bounds
             n[1]['strength'] = strength
             n[1]['min_sigma'] = min_sigma
             n[1]['max_sigma'] = max_sigma
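For completeness, a hypothetical call to the updated node, assuming ConditioningSetMask from nodes.py is in scope. The [tensor, options_dict] conditioning format and the one-element return tuple are assumed from the surrounding ComfyUI conventions rather than shown in this diff:

import torch

node = ConditioningSetMask()
cond = [[torch.zeros(1, 77, 768), {}]]   # illustrative conditioning entry
mask = torch.zeros(64, 64)
mask[8:24, 16:48] = 1.0
(out,) = node.append(cond, mask, set_area_to_bounds=True, strength=1.0)
# out[0][1] now carries 'mask', 'set_area_to_bounds', 'strength', 'min_sigma'
# and 'max_sigma'; the actual area is only computed later in resolve_cond_masks().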