Properly handle and reshape masks when used on 3d latents.
This commit is contained in:
parent
6ee066a14f
commit
8b90e50979
|
@ -1,14 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
def prepare_mask(noise_mask, shape, device):
|
def prepare_mask(noise_mask, shape, device):
|
||||||
"""ensures noise mask is of proper dimensions"""
|
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
|
||||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
|
||||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
|
||||||
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
|
|
||||||
noise_mask = noise_mask.to(device)
|
|
||||||
return noise_mask
|
|
||||||
|
|
||||||
def get_models_from_cond(cond, model_type):
|
def get_models_from_cond(cond, model_type):
|
||||||
models = []
|
models = []
|
||||||
|
|
|
@ -848,3 +848,24 @@ class ProgressBar:
|
||||||
|
|
||||||
def update(self, value):
|
def update(self, value):
|
||||||
self.update_absolute(self.current + value)
|
self.update_absolute(self.current + value)
|
||||||
|
|
||||||
|
def reshape_mask(input_mask, output_shape):
|
||||||
|
dims = len(output_shape) - 2
|
||||||
|
|
||||||
|
if dims == 1:
|
||||||
|
scale_mode = "linear"
|
||||||
|
|
||||||
|
if dims == 2:
|
||||||
|
mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||||
|
scale_mode = "bilinear"
|
||||||
|
|
||||||
|
if dims == 3:
|
||||||
|
if len(input_mask.shape) < 5:
|
||||||
|
mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||||
|
scale_mode = "trilinear"
|
||||||
|
|
||||||
|
mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_mode)
|
||||||
|
if mask.shape[1] < output_shape[1]:
|
||||||
|
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
|
||||||
|
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
|
||||||
|
return mask
|
||||||
|
|
Loading…
Reference in New Issue