Cleaner code.

This commit is contained in:
comfyanonymous 2024-08-08 20:07:09 -04:00
parent 037c38eb0f
commit 11200de970
1 changed files with 9 additions and 9 deletions

View File

@ -8,6 +8,7 @@ from torch import Tensor, nn
from .math import attention, rope from .math import attention, rope
import comfy.ops import comfy.ops
class EmbedND(nn.Module): class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list): def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__() super().__init__()
@ -174,20 +175,19 @@ class DoubleStreamBlock(nn.Module):
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention # run actual attention
q = torch.cat((txt_q, img_q), dim=2) attn = attention(torch.cat((txt_q, img_q), dim=2),
k = torch.cat((txt_k, img_k), dim=2) torch.cat((txt_k, img_k), dim=2),
v = torch.cat((txt_v, img_v), dim=2) torch.cat((txt_v, img_v), dim=2), pe=pe)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks # calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn) img += img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks # calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16: if txt.dtype == torch.float16:
txt = txt.clip(-65504, 65504) txt = txt.clip(-65504, 65504)
@ -243,7 +243,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe) attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x = x + mod.gate * output x += mod.gate * output
if x.dtype == torch.float16: if x.dtype == torch.float16:
x = x.clip(-65504, 65504) x = x.clip(-65504, 65504)
return x return x