Update attention.py to resolve conflicts with main branch updates

shawnington 2024-06-16 11:43:41 -04:00 committed by GitHub
parent 438b1ea399
commit efd5893913
1 changed file with 38 additions and 19 deletions

@@ -86,12 +86,16 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     cast_to_type = attn_precision if attn_precision is not None else q.dtype
-    b, _, dim_head = q.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
     scale = dim_head ** -0.5
     h = heads
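
A minimal sketch of the shape convention this new parameter implies (the variable names and sizes below are illustrative, not part of the diff): with skip_reshape=False the function receives q, k, v with heads fused into the channel axis, (batch, tokens, heads * dim_head); with skip_reshape=True the tensors arrive already split per head as (batch, heads, tokens, dim_head), so dim_head is read directly from the last axis instead of being derived by division.

    # Sketch of the two accepted input layouts; assumes torch is available.
    import torch

    b, n, heads, dim_head = 2, 77, 8, 64  # illustrative sizes only

    # skip_reshape=False: heads still fused into the channel axis.
    q_fused = torch.randn(b, n, heads * dim_head)   # (b, n, heads * dim_head)

    # skip_reshape=True: tensors already split per head.
    q_split = torch.randn(b, heads, n, dim_head)    # (b, heads, n, dim_head)

    # The unpacking in the diff mirrors these two cases.
    b1, _, d1 = q_fused.shape        # d1 == heads * dim_head, then d1 //= heads
    b2, _, _, d2 = q_split.shape     # d2 == dim_head directly
    assert d1 // heads == d2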
@@ -309,12 +313,15 @@ try:
 except:
     pass
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     cast_to_type = attn_precision if attn_precision is not None else q.dtype
-    b, _, dim_head = q.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
     disabled_xformers = False
@@ -329,10 +336,16 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     if disabled_xformers:
         return attention_pytorch(q, k, v, heads, mask)
-    q, k, v = map(
-        lambda t: t.reshape(b, -1, heads, dim_head).to(dtype=cast_to_type),
-        (q, k, v),
-    )
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head).to(dtype=cast_to_type),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.reshape(b, -1, heads, dim_head).to(dtype=cast_to_type),
+            (q, k, v),
+        )
     if mask is not None:
         pad = 8 - q.shape[1] % 8
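
A hedged sketch of what the skip_reshape branch above does for the xformers path: a per-head (b, heads, n, dim_head) tensor is folded into (b * heads, n, dim_head), i.e. the head axis is absorbed into the batch axis, while the else branch keeps the (b, n, heads, dim_head) layout produced from fused inputs. The reshape below uses plain torch only, so it can be checked without xformers installed; names and sizes are illustrative.

    # Sketch: skip_reshape folds heads into the batch dimension.
    import torch

    b, heads, n, dim_head = 2, 8, 77, 64  # illustrative sizes only
    q = torch.randn(b, heads, n, dim_head)            # skip_reshape=True layout

    q_folded = q.reshape(b * heads, -1, dim_head)     # -> (b * heads, n, dim_head)
    assert q_folded.shape == (b * heads, n, dim_head)

    # Round trip back to the per-head layout to confirm nothing is reordered.
    assert torch.equal(q_folded.reshape(b, heads, n, dim_head), q)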
@@ -347,16 +360,22 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     )
     return out
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     cast_to_type = attn_precision if attn_precision is not None else q.dtype
-    b, _, dim_head = q.shape
-    dim_head //= heads
-    q, k, v = map(
-        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).to(dtype=cast_to_type),
-        (q, k, v),
-    )
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+        q, k, v = map(
+            lambda t: t.to(dtype=cast_to_type), (q, k, v),
+        )
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).to(dtype=cast_to_type),
+            (q, k, v),
+        )
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     out = (
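
As a closing illustration, a hedged sketch of why the skip_reshape branch in attention_pytorch only needs a dtype cast: the conventional 4-D layout for torch.nn.functional.scaled_dot_product_attention is (batch, heads, tokens, dim_head), which is exactly the skip_reshape layout, whereas fused inputs still need the view/transpose shown in the else branch. Shapes and values below are illustrative only.

    # Sketch: both branches end up feeding SDPA the same (b, heads, n, dim_head) layout.
    import torch
    import torch.nn.functional as F

    b, n, heads, dim_head = 2, 77, 8, 64  # illustrative sizes only

    # Fused layout (skip_reshape=False) is split and transposed first.
    q_fused = torch.randn(b, n, heads * dim_head)
    q_a = q_fused.view(b, -1, heads, dim_head).transpose(1, 2)
    assert q_a.shape == (b, heads, n, dim_head)

    # Per-head layout (skip_reshape=True) already matches what SDPA expects.
    q_b = torch.randn(b, heads, n, dim_head)
    out = F.scaled_dot_product_attention(q_b, q_b, q_b, attn_mask=None, dropout_p=0.0, is_causal=False)
    assert out.shape == (b, heads, n, dim_head)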