From 797c4e8d3b56559bb205ff8aab97d97dca424b9a Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 20 May 2023 15:07:21 -0400
Subject: [PATCH] Simplify and improve some vae attention code.

---
 comfy/ldm/modules/diffusionmodules/model.py | 16 ++--------------
 1 file changed, 2 insertions(+), 14 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 5e4d2b60..05caf731 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -331,25 +331,13 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
 
         # compute attention
         B, C, H, W = q.shape
-        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
-
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(B, t.shape[1], 1, C)
-            .permute(0, 2, 1, 3)
-            .reshape(B * 1, t.shape[1], C)
-            .contiguous(),
+            lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
             (q, k, v),
         )
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
 
-        out = (
-            out.unsqueeze(0)
-            .reshape(B, 1, out.shape[1], C)
-            .permute(0, 2, 1, 3)
-            .reshape(B, out.shape[1], C)
-        )
-        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = out.transpose(2, 3).reshape(B, C, H, W)
         out = self.proj_out(out)
         return x+out
 
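
For reference, below is a minimal standalone check of the layout equivalence this
patch relies on: the removed rearrange + unsqueeze/reshape/permute chain and the
new single view/transpose both turn a (B, C, H, W) feature map into per-pixel
tokens for scaled_dot_product_attention; the new path just keeps an explicit
heads dimension of size 1. This check is not part of the patch, and the tensor
sizes in it are arbitrary example values, not anything from the commit.

# Standalone equivalence check (example sizes; not part of the patch).
import torch
from einops import rearrange

B, C, H, W = 2, 64, 8, 8
t = torch.randn(B, C, H, W)

# Old input path: rearrange to (B, H*W, C), then the
# unsqueeze/reshape/permute/reshape chain the patch deletes
# (which is an identity on that layout).
old = rearrange(t, 'b c h w -> b (h w) c')
old = (old.unsqueeze(3)
          .reshape(B, old.shape[1], 1, C)
          .permute(0, 2, 1, 3)
          .reshape(B * 1, old.shape[1], C)
          .contiguous())

# New input path: one view + transpose, keeping an explicit
# heads dimension of size 1 -> (B, 1, H*W, C).
new = t.view(B, 1, C, -1).transpose(2, 3).contiguous()

# Same data; the new path only carries the extra heads dim,
# which scaled_dot_product_attention accepts on 4-d inputs.
assert torch.equal(old, new.squeeze(1))

# Old output path: the deleted chain was a no-op on (B, H*W, C),
# followed by a rearrange back to (B, C, H, W).
out4 = torch.randn(B, 1, H * W, C)   # stand-in for the sdpa output
out3 = out4.squeeze(1)
old_out = rearrange(out3, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)

# New output path: transpose + reshape, as in the patch.
new_out = out4.transpose(2, 3).reshape(B, C, H, W)
assert torch.equal(old_out, new_out)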