Add pytorch attention support to VAE.
parent a256a2abde
commit 83f23f82b8
@@ -479,23 +479,19 @@ class CrossAttentionPytorch(nn.Module):
         return self.to_out(out)
 
 import sys
-if model_management.xformers_enabled() == False:
+if model_management.xformers_enabled():
+    print("Using xformers cross attention")
+    CrossAttention = MemoryEfficientCrossAttention
+elif model_management.pytorch_attention_enabled():
+    print("Using pytorch cross attention")
+    CrossAttention = CrossAttentionPytorch
+else:
     if "--use-split-cross-attention" in sys.argv:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:
-        if "--use-pytorch-cross-attention" in sys.argv:
-            print("Using pytorch cross attention")
-            torch.backends.cuda.enable_math_sdp(False)
-            torch.backends.cuda.enable_flash_sdp(True)
-            torch.backends.cuda.enable_mem_efficient_sdp(True)
-            CrossAttention = CrossAttentionPytorch
-        else:
-            print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-            CrossAttention = CrossAttentionBirchSan
-else:
-    print("Using xformers cross attention")
-    CrossAttention = MemoryEfficientCrossAttention
+        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        CrossAttention = CrossAttentionBirchSan
 
 
 class BasicTransformerBlock(nn.Module):
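This hunk only changes how the module-level CrossAttention alias is picked: xformers is checked first, the new pytorch attention flag second, and the split / sub-quadratic optimizations remain the fallback. The sdp backend toggles that used to live in this block move into the model management module (see the ENABLE_PYTORCH_ATTENTION hunk further down). The body of CrossAttentionPytorch sits outside this hunk; as a rough illustration of the primitive it builds on, a minimal multi-head cross attention written directly against torch.nn.functional.scaled_dot_product_attention could look like the sketch below (the head split/merge and shapes are illustrative assumptions, not the repository's exact code):

import torch
import torch.nn.functional as F

def sdp_cross_attention(q, k, v, heads):
    """Illustrative attention over (B, N, heads*dim_head) tensors."""
    B, Nq, inner = q.shape
    dim_head = inner // heads
    # (B, N, heads*dim_head) -> (B, heads, N, dim_head)
    q, k, v = (t.reshape(B, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    # (B, heads, Nq, dim_head) -> (B, Nq, heads*dim_head)
    return out.transpose(1, 2).reshape(B, Nq, heads * dim_head)

x   = torch.randn(2, 77, 8 * 64)   # query tokens
ctx = torch.randn(2, 154, 8 * 64)  # context tokens (cross attention)
print(sdp_cross_attention(x, ctx, ctx, heads=8).shape)  # torch.Size([2, 77, 512])

scaled_dot_product_attention dispatches to the flash, memory-efficient, or math kernel according to the torch.backends.cuda flags configured at startup.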
@@ -299,6 +299,64 @@ class MemoryEfficientAttnBlock(nn.Module):
         out = self.proj_out(out)
         return x+out
 
+class MemoryEfficientAttnBlockPytorch(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x+out
 
 class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
     def forward(self, x, context=None, mask=None):
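MemoryEfficientAttnBlockPytorch mirrors the existing MemoryEfficientAttnBlock used for the VAE but calls torch.nn.functional.scaled_dot_product_attention instead of xformers: the feature map is flattened from (B, C, H, W) into B batches of H*W tokens of dimension C and attended as a single head. The unsqueeze/reshape/permute chain effectively round-trips the (B, H*W, C) tensor, so its main practical effect is the .contiguous() call. A small self-contained check (illustrative, not repository code) that single-head SDPA over a flattened feature map matches the classic softmax(QK^T / sqrt(C)) V attention of the vanilla block:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 1, 8, 4, 4
q, k, v = (torch.randn(B, C, H, W) for _ in range(3))

# flatten the spatial grid into H*W tokens of dimension C: (B, HW, C)
flat = lambda t: t.reshape(B, C, H * W).permute(0, 2, 1)
qf, kf, vf = flat(q), flat(k), flat(v)

out_sdp = F.scaled_dot_product_attention(qf, kf, vf)            # fused kernel
attn = torch.softmax(qf @ kf.transpose(1, 2) / C ** 0.5, dim=-1)
out_ref = attn @ vf                                              # manual attention
assert torch.allclose(out_sdp, out_ref, atol=1e-5)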
@@ -313,6 +371,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
     if model_management.xformers_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
+    if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
+        attn_type = "vanilla-pytorch"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
@@ -320,6 +380,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     elif attn_type == "vanilla-xformers":
         print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "vanilla-pytorch":
+        return MemoryEfficientAttnBlockPytorch(in_channels)
     elif type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
         return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
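In make_attn, a requested "vanilla" attention block is silently upgraded to "vanilla-pytorch" when pytorch attention is enabled, mirroring the existing xformers upgrade, and a new branch constructs MemoryEfficientAttnBlockPytorch. Because the upgrade happens after the assert, "vanilla-pytorch" never needs to appear in the allowed-types list; note also that the last elif shown in the context compares type rather than attn_type, which looks like a pre-existing typo unrelated to this commit. A hypothetical standalone sketch of the same upgrade logic (names are illustrative, not the repository's API):

def pick_attn_type(requested, xformers_ok, pytorch_ok):
    # "vanilla" is upgraded to the best available backend; everything else
    # passes through untouched.
    if requested == "vanilla" and xformers_ok:
        return "vanilla-xformers"
    if requested == "vanilla" and pytorch_ok:
        return "vanilla-pytorch"
    return requested

assert pick_attn_type("vanilla", xformers_ok=False, pytorch_ok=True) == "vanilla-pytorch"
assert pick_attn_type("vanilla", xformers_ok=True, pytorch_ok=False) == "vanilla-xformers"
assert pick_attn_type("linear", xformers_ok=True, pytorch_ok=True) == "linear"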
@@ -41,6 +41,14 @@ else:
     except:
         XFORMERS_IS_AVAILBLE = False
 
+ENABLE_PYTORCH_ATTENTION = False
+if "--use-pytorch-cross-attention" in sys.argv:
+    torch.backends.cuda.enable_math_sdp(True)
+    torch.backends.cuda.enable_flash_sdp(True)
+    torch.backends.cuda.enable_mem_efficient_sdp(True)
+    ENABLE_PYTORCH_ATTENTION = True
+    XFORMERS_IS_AVAILBLE = False
+
 
 if "--cpu" in sys.argv:
     vram_state = CPU
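The --use-pytorch-cross-attention flag is now handled once at startup in the model management module: it enables all three CUDA scaled-dot-product backends (math is set to True here, whereas the removed attention.py code disabled it), records ENABLE_PYTORCH_ATTENTION, and forces XFORMERS_IS_AVAILBLE to False so the pytorch branch wins the selection checks above. A hedged sketch of inspecting the resulting backend state, assuming the torch 2.0 query helpers and not repository code:

import torch

torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

# the matching *_sdp_enabled() getters report what scaled_dot_product_attention
# is allowed to dispatch to on CUDA devices
print("flash:", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math:", torch.backends.cuda.math_sdp_enabled())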
@@ -175,6 +183,9 @@ def xformers_enabled():
         return False
     return XFORMERS_IS_AVAILBLE
 
+def pytorch_attention_enabled():
+    return ENABLE_PYTORCH_ATTENTION
+
 def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
main.py
@@ -15,6 +15,7 @@ if __name__ == "__main__":
     print("\t--port 8188\t\t\tSet the listen port.")
     print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
     print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
+    print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
     print("\t--disable-xformers\t\tdisables xformers")
     print()
     print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
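With these pieces in place, the new path is opted into from the command line, presumably as python main.py --use-pytorch-cross-attention; that run then uses torch.nn.functional.scaled_dot_product_attention for both the cross attention classes and the VAE attention blocks, and xformers is ignored.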