Update attention.py to resolve conflicts with main branch updates
commit efd5893913
parent 438b1ea399
@@ -86,12 +86,16 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     cast_to_type = attn_precision if attn_precision is not None else q.dtype
 
-    b, _, dim_head = q.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
 
     scale = dim_head ** -0.5
 
     h = heads
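Note (added for clarity, not part of the commit): the new skip_reshape flag lets callers pass q, k, v that have already been split into per-head tensors. A minimal sketch of the two input layouts the shape parsing above appears to accept; the tensor names and sizes are made up for illustration:

import torch

heads, dim_head, b, seq = 8, 64, 2, 77

# skip_reshape=False: heads are still fused into the last (channel) dimension,
# so dim_head has to be recovered by dividing by the head count.
q_fused = torch.randn(b, seq, heads * dim_head)   # (b, seq, heads * dim_head)
_, _, channels = q_fused.shape
assert channels // heads == dim_head

# skip_reshape=True: heads are already split out, so dim_head is read directly
# from the last dimension (assumed layout: (b, heads, seq, dim_head)).
q_split = torch.randn(b, heads, seq, dim_head)
_, _, _, d = q_split.shape
assert d == dim_head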
@@ -309,12 +313,15 @@ try:
 except:
     pass
 
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     cast_to_type = attn_precision if attn_precision is not None else q.dtype
 
-    b, _, dim_head = q.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
 
     disabled_xformers = False
 
@@ -329,10 +336,16 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     if disabled_xformers:
         return attention_pytorch(q, k, v, heads, mask)
 
-    q, k, v = map(
-        lambda t: t.reshape(b, -1, heads, dim_head).to(dtype=cast_to_type),
-        (q, k, v),
-    )
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head).to(dtype=cast_to_type),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.reshape(b, -1, heads, dim_head).to(dtype=cast_to_type),
+            (q, k, v),
+        )
 
     if mask is not None:
         pad = 8 - q.shape[1] % 8
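Sketch (not from the commit) of what the skip_reshape branch above does on the xformers path: tensors assumed to arrive as (b, heads, seq, dim_head) are folded into (b * heads, seq, dim_head), whereas the default branch still splits the fused channel dimension:

import torch

b, heads, seq, dim_head = 2, 8, 77, 64

# skip_reshape=True: batch and head dimensions are folded together.
q_pre_split = torch.randn(b, heads, seq, dim_head)
q_xformers = q_pre_split.reshape(b * heads, -1, dim_head)
print(q_xformers.shape)   # torch.Size([16, 77, 64])

# skip_reshape=False: the fused channel dimension is split into heads instead.
q_fused = torch.randn(b, seq, heads * dim_head)
q_xformers_default = q_fused.reshape(b, -1, heads, dim_head)
print(q_xformers_default.shape)   # torch.Size([2, 77, 8, 64])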
@@ -347,16 +360,22 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     )
     return out
 
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     cast_to_type = attn_precision if attn_precision is not None else q.dtype
 
-    b, _, dim_head = q.shape
-    dim_head //= heads
-    q, k, v = map(
-        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).to(dtype=cast_to_type),
-        (q, k, v),
-    )
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+        q, k, v = map(
+            lambda t: t.to(dtype=cast_to_type), (q, k, v),
+        )
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).to(dtype=cast_to_type),
+            (q, k, v),
+        )
 
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     out = (
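A standalone sketch (hypothetical tensors, not part of the commit) of what the skip_reshape=True path in attention_pytorch above reduces to: inputs already laid out as (b, heads, seq, dim_head) go straight to scaled_dot_product_attention with no view/transpose, only a dtype cast remaining.

import torch
import torch.nn.functional as F

b, heads, seq, dim_head = 1, 8, 77, 64
q = torch.randn(b, heads, seq, dim_head)
k = torch.randn(b, heads, seq, dim_head)
v = torch.randn(b, heads, seq, dim_head)

# The reshape/transpose performed in the default branch is skipped entirely.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)   # torch.Size([1, 8, 77, 64])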