Cleaner code.
parent 037c38eb0f
commit 11200de970
@@ -8,6 +8,7 @@ from torch import Tensor, nn
 from .math import attention, rope
 import comfy.ops
 
+
 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
@@ -174,20 +175,19 @@ class DoubleStreamBlock(nn.Module):
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention
-        q = torch.cat((txt_q, img_q), dim=2)
-        k = torch.cat((txt_k, img_k), dim=2)
-        v = torch.cat((txt_v, img_v), dim=2)
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2), pe=pe)
 
-        attn = attention(q, k, v, pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img += img_mod1.gate * self.img_attn.proj(img_attn)
+        img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
 
         # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
         if txt.dtype == torch.float16:
             txt = txt.clip(-65504, 65504)
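The attention change in this hunk is purely a readability refactor: the text and image streams are still concatenated along the sequence axis (dim=2 of a (batch, heads, seq, head_dim) layout), attended over in a single call, and the joint result is split back at txt.shape[1]. Below is a minimal standalone sketch of that pattern; it uses torch.nn.functional.scaled_dot_product_attention as a stand-in for the repository's attention() helper (which additionally applies the rotary embedding pe), so the names and shapes are illustrative assumptions rather than the project's API.

import torch
import torch.nn.functional as F

def joint_attention(txt_q, txt_k, txt_v, img_q, img_k, img_v):
    # Concatenate text and image tokens along the sequence axis (dim=2),
    # attend over the combined sequence once, then merge the heads.
    q = torch.cat((txt_q, img_q), dim=2)
    k = torch.cat((txt_k, img_k), dim=2)
    v = torch.cat((txt_v, img_v), dim=2)
    out = F.scaled_dot_product_attention(q, k, v)   # (B, H, L_txt + L_img, D)
    out = out.transpose(1, 2).flatten(2)            # (B, L_txt + L_img, H * D)
    # Split the joint result back into the two streams, mirroring
    # attn[:, : txt.shape[1]] and attn[:, txt.shape[1]:] in the diff.
    txt_len = txt_q.shape[2]
    return out[:, :txt_len], out[:, txt_len:]

B, H, D = 1, 4, 64
txt = [torch.randn(B, H, 10, D) for _ in range(3)]   # q, k, v for 10 text tokens
img = [torch.randn(B, H, 32, D) for _ in range(3)]   # q, k, v for 32 image tokens
txt_attn, img_attn = joint_attention(*txt, *img)
print(txt_attn.shape, img_attn.shape)                # (1, 10, 256) (1, 32, 256)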
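The x = x + ... to x += ... rewrites are equal in value, but on tensors += is an in-place operation (equivalent to Tensor.add_): it writes into the existing buffer instead of allocating a new tensor, any other name bound to the same storage sees the update, and in-place ops can conflict with autograd when the overwritten value is needed for the backward pass (not a concern on inference-only paths). A small sketch of the aliasing difference, independent of this repository:

import torch

a = torch.zeros(3)
view_of_a = a        # same underlying storage as a
a += 1               # in-place add: mutates the shared buffer
print(view_of_a)     # tensor([1., 1., 1.])

b = torch.zeros(3)
view_of_b = b
b = b + 1            # out-of-place add: b is rebound to a fresh tensor
print(view_of_b)     # tensor([0., 0., 0.])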
@@ -243,7 +243,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x = x + mod.gate * output
+        x += mod.gate * output
         if x.dtype == torch.float16:
             x = x.clip(-65504, 65504)
         return x
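The clip(-65504, 65504) guards kept in both blocks clamp the residual streams to the largest finite float16 value, so activations that would overflow in fp16 saturate instead of becoming inf and poisoning later layers with NaNs. A quick illustration in plain PyTorch (not code from this repository) of why that particular bound appears:

import torch

print(torch.finfo(torch.float16).max)            # 65504.0, the largest finite fp16 value

x = torch.tensor([70000.0, -70000.0, 123.0])
print(x.to(torch.float16))                       # tensor([inf, -inf, 123.], dtype=torch.float16)
print(x.clip(-65504, 65504).to(torch.float16))   # clamped: tensor([65504., -65504., 123.], dtype=torch.float16)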
|