Allow attn_mask in attention_pytorch.

This commit is contained in:
comfyanonymous 2023-10-11 20:24:17 -04:00
parent 1a4bd9e9a6
commit ac7d8cfa87
1 changed file with 1 addition and 1 deletion

View File

@ -284,7 +284,7 @@ def attention_pytorch(q, k, v, heads, mask=None):
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if exists(mask):
raise NotImplementedError