From 8b90e50979b0d33e1f8d10d5c938361f59f95474 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 9 Nov 2024 07:10:43 -0500
Subject: [PATCH] Properly handle and reshape masks when used on 3d latents.

---
Review amendment: reshape_mask now always binds `mask` before interpolating
(1d targets and already-5d masks used to hit an unbound local) and rejects
unsupported ranks with ValueError. The blob ids on the `index` lines below
are copied from the original commit.

 comfy/sampler_helpers.py |  9 +++------
 comfy/utils.py           | 42 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 4a2ec123..1879e670 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -1,14 +1,11 @@
 import torch
 import comfy.model_management
 import comfy.conds
+import comfy.utils
 
 def prepare_mask(noise_mask, shape, device):
-    """ensures noise mask is of proper dimensions"""
-    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
-    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
-    noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
-    noise_mask = noise_mask.to(device)
-    return noise_mask
+    """Resize noise_mask to `shape` with comfy.utils.reshape_mask and move it to `device`."""
+    return comfy.utils.reshape_mask(noise_mask, shape).to(device)
 
 def get_models_from_cond(cond, model_type):
     models = []
diff --git a/comfy/utils.py b/comfy/utils.py
index cc92e111..3c5d06a4 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -848,3 +848,45 @@ class ProgressBar:
 
     def update(self, value):
         self.update_absolute(self.current + value)
+
+def reshape_mask(input_mask, output_shape):
+    """Resize a mask so it lines up with a tensor of shape output_shape.
+
+    output_shape is read as (batch, channels, *spatial). The number of
+    spatial dims picks the interpolation mode: 1 -> linear, 2 -> bilinear,
+    3 -> trilinear.
+
+    Args:
+        input_mask: mask tensor. With 2 spatial dims it is reshaped to
+            (-1, 1, H, W); with 3 spatial dims a mask of fewer than 5 dims
+            is reshaped to (1, 1, -1, H, W), H and W being its last two
+            dims. In every other case it is interpolated as given.
+        output_shape: shape of the tensor the mask will be applied to.
+
+    Returns:
+        The interpolated mask, with dim 1 tiled up to output_shape[1] when
+        it is smaller, passed through comfy.utils.repeat_to_batch_size
+        with output_shape[0].
+
+    Raises:
+        ValueError: if output_shape does not have 1, 2 or 3 spatial dims.
+    """
+    dims = len(output_shape) - 2
+    scale_modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
+    if dims not in scale_modes:
+        raise ValueError("reshape_mask: unsupported number of spatial dims: {}".format(dims))
+
+    # Default to the input itself so the 1d path and an already-5d mask on the
+    # 3d path reach interpolate() with `mask` defined.
+    mask = input_mask
+    if dims == 2:
+        mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
+    elif dims == 3 and len(input_mask.shape) < 5:
+        mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
+
+    mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_modes[dims])
+    if mask.shape[1] < output_shape[1]:
+        # repeat() multiplies dim 1 by output_shape[1]; the slice trims it to exactly output_shape[1]
+        mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:, :output_shape[1]]
+    mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
+    return mask