import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):
    """ControlNet-style embedder for an MMDiT diffusion model (e.g. SD3).

    Patch-embeds the noisy latent together with a control hint, runs the
    result through a stack of DismantledBlocks, and returns one projected
    residual per double block of the main model.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations=None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size

        # Patch embedding for the noisy latent; strict size checking is only
        # enabled when no maximum positional-embedding size is given.
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        # Pooled conditioning: either a single projection from adm_in_channels
        # to the hidden size, or (double_y_emb) two stacked projections.
        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        # One linear projection per transformer block; its output is the
        # residual handed to the corresponding double block(s) of the main model.
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            self.controlnet_blocks.append(controlnet_block)

        # Patch embedding for the control hint (already in latent space).
        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint=None,
    ) -> Dict[str, Tuple[Tensor, ...]]:
        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)

        # Conditioning vector: timestep embedding plus (optionally) the pooled
        # embedding projected to the hidden size.
        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        # Inject the control hint by adding its patch embedding to the latent tokens.
        x = x + self.pos_embed_input(hint)

        block_out = ()

        # The main model may have more double blocks than this ControlNet has
        # transformer blocks, so each projected residual is repeated
        # ceil(main_model_double / num_layers) times to cover them.
        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for i in range(len(self.transformer_blocks)):
            out = self.transformer_blocks[i](x, c)
            if not self.double_y_emb:
                x = out
            block_out += (self.controlnet_blocks[i](out),) * repeat

        return {"output": block_out}
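

if __name__ == "__main__":
    # Illustrative smoke test, not part of the module. It assumes a ComfyUI
    # checkout on the import path (so the mmdit imports above resolve) and
    # uses made-up, deliberately small hyperparameters rather than values from
    # any real checkpoint.
    class _PlainOps:
        # Hypothetical stand-in for ComfyUI's `operations` namespace; it is
        # assumed here that the mmdit submodules only request these layer
        # types from it. Pass the real comfy ops object in actual use.
        Linear = nn.Linear
        Conv2d = nn.Conv2d
        LayerNorm = nn.LayerNorm

    embedder = ControlNetEmbedder(
        img_size=16,
        patch_size=2,
        in_chans=16,
        attention_head_dim=8,
        num_attention_heads=4,
        adm_in_channels=32,
        num_layers=3,
        main_model_double=7,
        double_y_emb=False,
        device=torch.device("cpu"),
        dtype=torch.float32,
        operations=_PlainOps,
    )

    latent = torch.randn(1, 16, 16, 16)   # noisy latent, NCHW
    hint = torch.randn(1, 16, 16, 16)     # control hint, already in latent space
    t = torch.tensor([999.0])             # one timestep per batch element
    pooled = torch.randn(1, 32)           # pooled conditioning (adm_in_channels)

    out = embedder(latent, t, y=pooled, hint=hint)
    # Expect num_layers * ceil(7 / 3) = 9 residuals, each [1, tokens, hidden].
    print(len(out["output"]), out["output"][0].shape)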