Allow attn_mask in attention_pytorch.
This commit is contained in:
parent
1a4bd9e9a6
commit
ac7d8cfa87
|
@ -284,7 +284,7 @@ def attention_pytorch(q, k, v, heads, mask=None):
|
|||
(q, k, v),
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
|
|
Loading…
Reference in New Issue