diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 3f543abd..7b4ee215 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -372,10 +372,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         )
 
     if mask is not None:
-        pad = 8 - q.shape[1] % 8
-        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
-        mask_out[:, :, :mask.shape[-1]] = mask
-        mask = mask_out[:, :, :mask.shape[-1]]
+        pad = 8 - mask.shape[-1] % 8
+        mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+        mask_out[..., :mask.shape[-1]] = mask
+        mask = mask_out[..., :mask.shape[-1]]
 
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
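Note on the change: the removed branch sized the pad from q.shape[1] and built a 3-D buffer, which breaks whenever the mask's key length differs from the query length; the new branch pads the mask's own last dimension up to a multiple of 8 and allocates a 4-D [batch, heads, q_len, kv_len + pad] buffer matching the bias layout xformers expects, then hands back the unpadded slice so only the underlying storage carries the padding. A minimal standalone sketch of that trick follows; the helper name pad_mask_for_xformers and the standalone framing are illustrative, and the 8-element alignment rationale for attn_bias storage is an assumption based on the usual xformers requirement, not stated in this patch.

import torch

def pad_mask_for_xformers(mask, batch, heads, q_len):
    # Hypothetical helper mirroring the patched branch: allocate a buffer
    # whose last dimension is padded up to a multiple of 8, copy the mask
    # in, and return an unpadded *view* into it. The logical shape stays
    # [B, H, M_q, M_kv], but the underlying row stride is 8-aligned.
    kv_len = mask.shape[-1]
    pad = 8 - kv_len % 8  # yields 8 (not 0) when already aligned; the extra slack is harmless
    mask_out = torch.empty([batch, heads, q_len, kv_len + pad],
                           dtype=mask.dtype, device=mask.device)
    mask_out[..., :kv_len] = mask   # broadcast-copy the real mask into the padded buffer
    return mask_out[..., :kv_len]   # sliced view over the aligned storage

# Usage: a [B, 1, M, M] additive bias broadcast across 8 heads.
b, h, m = 2, 8, 77
mask = torch.zeros(b, 1, m, m)
aligned = pad_mask_for_xformers(mask, b, h, m)
assert aligned.shape == (b, h, m, m)
assert aligned.stride(-2) % 8 == 0  # row stride is 80, a multiple of 8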