Basic hunyuan dit implementation. (#4102)
* Let tokenizers return weights to be stored in the saved checkpoint. * Basic hunyuan dit implementation. * Fix some resolutions not working. * Support hydit checkpoint save. * Init with right dtype. * Switch to optimized attention in pooler. * Fix black images on hunyuan dit.
This commit is contained in:
parent
f87810cd3e
commit
a5f4292f9f
|
@ -0,0 +1,219 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Union, Optional
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
|
||||
"""
|
||||
Reshape frequency tensor for broadcasting it with another tensor.
|
||||
|
||||
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
||||
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
||||
|
||||
Args:
|
||||
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
|
||||
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
||||
head_first (bool): head dimension first (except batch dim) or not.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Reshaped frequency tensor.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the frequency tensor doesn't match the expected shape.
|
||||
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
||||
"""
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
|
||||
if isinstance(freqs_cis, tuple):
|
||||
# freqs_cis: (cos, sin) in real space
|
||||
if head_first:
|
||||
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
else:
|
||||
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
else:
|
||||
# freqs_cis: values in complex space
|
||||
if head_first:
|
||||
assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
else:
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: Optional[torch.Tensor],
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
head_first: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
||||
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
||||
returned as real tensors.
|
||||
|
||||
Args:
|
||||
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
|
||||
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
|
||||
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
|
||||
head_first (bool): head dimension first (except batch dim) or not.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
|
||||
"""
|
||||
xk_out = None
|
||||
if isinstance(freqs_cis, tuple):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
if xk is not None:
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
else:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||
if xk is not None:
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
"""
|
||||
Use QK Normalization.
|
||||
"""
|
||||
def __init__(self,
|
||||
qdim,
|
||||
kdim,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_norm=False,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
attn_precision=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.attn_precision = attn_precision
|
||||
self.qdim = qdim
|
||||
self.kdim = kdim
|
||||
self.num_heads = num_heads
|
||||
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
|
||||
self.head_dim = self.qdim // num_heads
|
||||
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
||||
self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
|
||||
|
||||
# TODO: eps should be 1 / 65530 if using fp16
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, y, freqs_cis_img=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x: torch.Tensor
|
||||
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
|
||||
y: torch.Tensor
|
||||
(batch, seqlen2, hidden_dim2)
|
||||
freqs_cis_img: torch.Tensor
|
||||
(batch, hidden_dim // 2), RoPE for image
|
||||
"""
|
||||
b, s1, c = x.shape # [b, s1, D]
|
||||
_, s2, c = y.shape # [b, s2, 1024]
|
||||
|
||||
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
|
||||
kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
|
||||
k, v = kv.unbind(dim=2) # [b, s, h, d]
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if freqs_cis_img is not None:
|
||||
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
|
||||
assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
|
||||
q = qq
|
||||
|
||||
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
|
||||
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
|
||||
v = v.transpose(-2, -3).contiguous()
|
||||
|
||||
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
|
||||
|
||||
out = self.out_proj(context) # context.reshape - B, L1, -1
|
||||
out = self.proj_drop(out)
|
||||
|
||||
out_tuple = (out,)
|
||||
|
||||
return out_tuple
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
We rename some layer names to align with flash attention
|
||||
"""
|
||||
def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.attn_precision = attn_precision
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.head_dim = self.dim // num_heads
|
||||
# This assertion is aligned with flash attention
|
||||
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
# qkv --> Wqkv
|
||||
self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
# TODO: eps should be 1 / 65530 if using fp16
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, freqs_cis_img=None):
|
||||
B, N, C = x.shape
|
||||
qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
|
||||
q, k, v = qkv.unbind(0) # [b, h, s, d]
|
||||
q = self.q_norm(q) # [b, h, s, d]
|
||||
k = self.k_norm(k) # [b, h, s, d]
|
||||
|
||||
# Apply RoPE if needed
|
||||
if freqs_cis_img is not None:
|
||||
qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
|
||||
assert qq.shape == q.shape and kk.shape == k.shape, \
|
||||
f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
|
||||
q, k = qq, kk
|
||||
|
||||
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
|
||||
x = self.out_proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
out_tuple = (x,)
|
||||
|
||||
return out_tuple
|
|
@ -0,0 +1,402 @@
|
|||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from torch.utils import checkpoint
|
||||
|
||||
from .attn_layers import Attention, CrossAttention
|
||||
from .poolers import AttentionPool
|
||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
||||
|
||||
def calc_rope(x, patch_size, head_size):
|
||||
th = (x.shape[2] + (patch_size // 2)) // patch_size
|
||||
tw = (x.shape[3] + (patch_size // 2)) // patch_size
|
||||
base_size = 512 // 8 // patch_size
|
||||
start, stop = get_fill_resize_and_crop((th, tw), base_size)
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
|
||||
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
|
||||
return rope
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class HunYuanDiTBlock(nn.Module):
|
||||
"""
|
||||
A HunYuanDiT block with `add` conditioning.
|
||||
"""
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
c_emb_size,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
text_states_dim=1024,
|
||||
qk_norm=False,
|
||||
norm_type="layer",
|
||||
skip=False,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
use_ele_affine = True
|
||||
|
||||
if norm_type == "layer":
|
||||
norm_layer = operations.LayerNorm
|
||||
elif norm_type == "rms":
|
||||
norm_layer = RMSNorm
|
||||
else:
|
||||
raise ValueError(f"Unknown norm_type: {norm_type}")
|
||||
|
||||
# ========================= Self-Attention =========================
|
||||
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# ========================= FFN =========================
|
||||
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# ========================= Add =========================
|
||||
# Simply use add like SDXL.
|
||||
self.default_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
# ========================= Cross-Attention =========================
|
||||
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
|
||||
qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
# ========================= Skip Connection =========================
|
||||
if skip:
|
||||
self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
|
||||
else:
|
||||
self.skip_linear = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
|
||||
# Long Skip Connection
|
||||
if self.skip_linear is not None:
|
||||
cat = torch.cat([x, skip], dim=-1)
|
||||
cat = self.skip_norm(cat)
|
||||
x = self.skip_linear(cat)
|
||||
|
||||
# Self-Attention
|
||||
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
|
||||
attn_inputs = (
|
||||
self.norm1(x) + shift_msa, freq_cis_img,
|
||||
)
|
||||
x = x + self.attn1(*attn_inputs)[0]
|
||||
|
||||
# Cross-Attention
|
||||
cross_inputs = (
|
||||
self.norm3(x), text_states, freq_cis_img
|
||||
)
|
||||
x = x + self.attn2(*cross_inputs)[0]
|
||||
|
||||
# FFN Layer
|
||||
mlp_inputs = self.norm2(x)
|
||||
x = x + self.mlp(mlp_inputs)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
|
||||
return self._forward(x, c, text_states, freq_cis_img, skip)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of HunYuanDiT.
|
||||
"""
|
||||
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class HunYuanDiT(nn.Module):
|
||||
"""
|
||||
HunYuanDiT: Diffusion model with a Transformer backbone.
|
||||
|
||||
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
||||
|
||||
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args: argparse.Namespace
|
||||
The arguments parsed by argparse.
|
||||
input_size: tuple
|
||||
The size of the input image.
|
||||
patch_size: int
|
||||
The size of the patch.
|
||||
in_channels: int
|
||||
The number of input channels.
|
||||
hidden_size: int
|
||||
The hidden size of the transformer backbone.
|
||||
depth: int
|
||||
The number of transformer blocks.
|
||||
num_heads: int
|
||||
The number of attention heads.
|
||||
mlp_ratio: float
|
||||
The ratio of the hidden size of the MLP in the transformer block.
|
||||
log_fn: callable
|
||||
The logging function.
|
||||
"""
|
||||
#@register_to_config
|
||||
def __init__(self,
|
||||
input_size: tuple = 32,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
hidden_size: int = 1152,
|
||||
depth: int = 28,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
text_states_dim = 1024,
|
||||
text_states_dim_t5 = 2048,
|
||||
text_len = 77,
|
||||
text_len_t5 = 256,
|
||||
qk_norm = True,# See http://arxiv.org/abs/2302.05442 for details.
|
||||
size_cond = False,
|
||||
use_style_cond = False,
|
||||
learn_sigma = True,
|
||||
norm = "layer",
|
||||
log_fn: callable = print,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.log_fn = log_fn
|
||||
self.depth = depth
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.text_states_dim = text_states_dim
|
||||
self.text_states_dim_t5 = text_states_dim_t5
|
||||
self.text_len = text_len
|
||||
self.text_len_t5 = text_len_t5
|
||||
self.size_cond = size_cond
|
||||
self.use_style_cond = use_style_cond
|
||||
self.norm = norm
|
||||
self.dtype = dtype
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
|
||||
self.mlp_t5 = nn.Sequential(
|
||||
operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
# learnable replace
|
||||
self.text_embedding_padding = nn.Parameter(
|
||||
torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))
|
||||
|
||||
# Attention pooling
|
||||
pooler_out_dim = 1024
|
||||
self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# Dimension of the extra input vectors
|
||||
self.extra_in_dim = pooler_out_dim
|
||||
|
||||
if self.size_cond:
|
||||
# Image size and crop size conditions
|
||||
self.extra_in_dim += 6 * 256
|
||||
|
||||
if self.use_style_cond:
|
||||
# Here we use a default learned embedder layer for future extension.
|
||||
self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device)
|
||||
self.extra_in_dim += hidden_size
|
||||
|
||||
# Text embedding for `add`
|
||||
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.extra_embedder = nn.Sequential(
|
||||
operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
text_states_dim=self.text_states_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_type=self.norm,
|
||||
skip=layer > depth // 2,
|
||||
attn_precision=attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for layer in range(depth)
|
||||
])
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
self.unpatchify_channels = self.out_channels
|
||||
|
||||
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
t,
|
||||
context,#encoder_hidden_states=None,
|
||||
text_embedding_mask=None,
|
||||
encoder_hidden_states_t5=None,
|
||||
text_embedding_mask_t5=None,
|
||||
image_meta_size=None,
|
||||
style=None,
|
||||
return_dict=False,
|
||||
control=None,
|
||||
transformer_options=None,
|
||||
):
|
||||
"""
|
||||
Forward pass of the encoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x: torch.Tensor
|
||||
(B, D, H, W)
|
||||
t: torch.Tensor
|
||||
(B)
|
||||
encoder_hidden_states: torch.Tensor
|
||||
CLIP text embedding, (B, L_clip, D)
|
||||
text_embedding_mask: torch.Tensor
|
||||
CLIP text embedding mask, (B, L_clip)
|
||||
encoder_hidden_states_t5: torch.Tensor
|
||||
T5 text embedding, (B, L_t5, D)
|
||||
text_embedding_mask_t5: torch.Tensor
|
||||
T5 text embedding mask, (B, L_t5)
|
||||
image_meta_size: torch.Tensor
|
||||
(B, 6)
|
||||
style: torch.Tensor
|
||||
(B)
|
||||
cos_cis_img: torch.Tensor
|
||||
sin_cis_img: torch.Tensor
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
encoder_hidden_states = context
|
||||
text_states = encoder_hidden_states # 2,77,1024
|
||||
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
||||
text_states_mask = text_embedding_mask.bool() # 2,77
|
||||
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
|
||||
b_t5, l_t5, c_t5 = text_states_t5.shape
|
||||
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
|
||||
text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024
|
||||
|
||||
clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
|
||||
|
||||
clip_t5_mask = clip_t5_mask
|
||||
text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
|
||||
|
||||
_, _, oh, ow = x.shape
|
||||
th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size
|
||||
|
||||
|
||||
# Get image RoPE embedding according to `reso`lution.
|
||||
freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads) #(cos_cis_img, sin_cis_img)
|
||||
|
||||
# ========================= Build time and image embedding =========================
|
||||
t = self.t_embedder(t, dtype=x.dtype)
|
||||
x = self.x_embedder(x)
|
||||
|
||||
# ========================= Concatenate all extra vectors =========================
|
||||
# Build text tokens with pooling
|
||||
extra_vec = self.pooler(encoder_hidden_states_t5)
|
||||
|
||||
# Build image meta size tokens if applicable
|
||||
if self.size_cond:
|
||||
image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype) # [B * 6, 256]
|
||||
image_meta_size = image_meta_size.view(-1, 6 * 256)
|
||||
extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
|
||||
|
||||
# Build style tokens
|
||||
if self.use_style_cond:
|
||||
if style is None:
|
||||
style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
|
||||
style_embedding = self.style_embedder(style)
|
||||
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||
|
||||
controls = None
|
||||
# ========================= Forward pass through HunYuanDiT blocks =========================
|
||||
skips = []
|
||||
for layer, block in enumerate(self.blocks):
|
||||
if layer > self.depth // 2:
|
||||
if controls is not None:
|
||||
skip = skips.pop() + controls.pop()
|
||||
else:
|
||||
skip = skips.pop()
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
else:
|
||||
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
|
||||
|
||||
if layer < (self.depth // 2 - 1):
|
||||
skips.append(x)
|
||||
if controls is not None and len(controls) != 0:
|
||||
raise ValueError("The number of controls is not equal to the number of skip connections.")
|
||||
|
||||
# ========================= Final layer =========================
|
||||
x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
|
||||
|
||||
if return_dict:
|
||||
return {'x': x}
|
||||
if self.learn_sigma:
|
||||
return x[:,:self.out_channels // 2,:oh,:ow]
|
||||
return x[:,:,:oh,:ow]
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.unpatchify_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
# h = w = int(x.shape[1] ** 0.5)
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
|
@ -0,0 +1,36 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention #TODO
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
|
||||
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
|
||||
self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(1, 0, 2) # NLC -> LNC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device) # (L+1)NC
|
||||
|
||||
q = self.q_proj(x[:1])
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
batch_size = q.shape[1]
|
||||
head_dim = self.embed_dim // self.num_heads
|
||||
q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
|
||||
k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
|
||||
v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
|
||||
|
||||
attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)
|
||||
|
||||
attn_output = self.c_proj(attn_output)
|
||||
return attn_output.squeeze(0)
|
|
@ -0,0 +1,224 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
|
||||
|
||||
def _to_tuple(x):
|
||||
if isinstance(x, int):
|
||||
return x, x
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def get_fill_resize_and_crop(src, tgt):
|
||||
th, tw = _to_tuple(tgt)
|
||||
h, w = _to_tuple(src)
|
||||
|
||||
tr = th / tw # base resolution
|
||||
r = h / w # target resolution
|
||||
|
||||
# resize
|
||||
if r > tr:
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h)) # resize the target resolution down based on the base resolution
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
def get_meshgrid(start, *args):
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start)
|
||||
start = (0, 0)
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start)
|
||||
stop = _to_tuple(args[0])
|
||||
num = (stop[0] - start[0], stop[1] - start[1])
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start)
|
||||
stop = _to_tuple(args[0])
|
||||
num = _to_tuple(args[1])
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0) # [2, W, H]
|
||||
return grid
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid = get_meshgrid(start, *args) # [2, H, w]
|
||||
# grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
# grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
# grid = np.stack(grid, axis=0) # [2, W, H]
|
||||
|
||||
grid = grid.reshape([2, 1, *grid.shape[1:]])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (W,H)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Rotary Positional Embedding Functions #
|
||||
#################################################################################
|
||||
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
|
||||
"""
|
||||
This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
embed_dim: int
|
||||
embedding dimension size
|
||||
start: int or tuple of int
|
||||
If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
|
||||
If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||
use_real: bool
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pos_embed: torch.Tensor
|
||||
[HW, D/2]
|
||||
"""
|
||||
grid = get_meshgrid(start, *args) # [2, H, w]
|
||||
grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution
|
||||
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
assert embed_dim % 4 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
|
||||
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
|
||||
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||
The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the frequency tensor.
|
||||
pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||
Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
|
||||
def calc_sizes(rope_img, patch_size, th, tw):
|
||||
if rope_img == 'extend':
|
||||
# Expansion mode
|
||||
sub_args = [(th, tw)]
|
||||
elif rope_img.startswith('base'):
|
||||
# Based on the specified dimensions, other dimensions are obtained through interpolation.
|
||||
base_size = int(rope_img[4:]) // 8 // patch_size
|
||||
start, stop = get_fill_resize_and_crop((th, tw), base_size)
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
else:
|
||||
raise ValueError(f"Unknown rope_img: {rope_img}")
|
||||
return sub_args
|
||||
|
||||
|
||||
def init_image_posemb(rope_img,
|
||||
resolutions,
|
||||
patch_size,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
log_fn,
|
||||
rope_real=True,
|
||||
):
|
||||
freqs_cis_img = {}
|
||||
for reso in resolutions:
|
||||
th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
|
||||
sub_args = calc_sizes(rope_img, patch_size, th, tw)
|
||||
freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
|
||||
log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
|
||||
f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
|
||||
return freqs_cis_img
|
|
@ -7,6 +7,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
|||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
import comfy.ldm.aura.mmdit
|
||||
import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.model_management
|
||||
|
@ -648,3 +649,35 @@ class StableAudio1(BaseModel):
|
|||
for l in s:
|
||||
sd["{}{}".format(k, l)] = s[l]
|
||||
return sd
|
||||
|
||||
class HunyuanDiT(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['text_embedding_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
|
||||
conditioning_mt5xl = kwargs.get("conditioning_mt5xl", None)
|
||||
if conditioning_mt5xl is not None:
|
||||
out['encoder_hidden_states_t5'] = comfy.conds.CONDRegular(conditioning_mt5xl)
|
||||
|
||||
attention_mask_mt5xl = kwargs.get("attention_mask_mt5xl", None)
|
||||
if attention_mask_mt5xl is not None:
|
||||
out['text_embedding_mask_t5'] = comfy.conds.CONDRegular(attention_mask_mt5xl)
|
||||
|
||||
width = kwargs.get("width", 768)
|
||||
height = kwargs.get("height", 768)
|
||||
crop_w = kwargs.get("crop_w", 0)
|
||||
crop_h = kwargs.get("crop_h", 0)
|
||||
target_width = kwargs.get("target_width", width)
|
||||
target_height = kwargs.get("target_height", height)
|
||||
|
||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
|
||||
|
|
|
@ -115,6 +115,19 @@ def detect_unet_config(state_dict, key_prefix):
|
|||
unet_config["n_layers"] = double_layers + single_layers
|
||||
return unet_config
|
||||
|
||||
if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: #Hunyuan DiT
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "hydit"
|
||||
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
|
||||
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
|
||||
unet_config["mlp_ratio"] = 4.3637
|
||||
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
|
||||
unet_config["size_cond"] = True
|
||||
unet_config["use_style_cond"] = True
|
||||
unet_config["image_model"] = "hydit1"
|
||||
return unet_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ from . import sdxl_clip
|
|||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
|
@ -385,6 +386,7 @@ class CLIPType(Enum):
|
|||
STABLE_CASCADE = 2
|
||||
SD3 = 3
|
||||
STABLE_AUDIO = 4
|
||||
HUNYUAN_DIT = 5
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
clip_data = []
|
||||
|
@ -433,6 +435,9 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
||||
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
|
|
@ -8,6 +8,7 @@ from . import sdxl_clip
|
|||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
|
@ -580,6 +581,45 @@ class AuraFlow(supported_models_base.BASE):
|
|||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
|
||||
class HunyuanDiT(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hydit",
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start": 0.00085,
|
||||
"linear_end": 0.018,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanDiT(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
|
||||
|
||||
class HunyuanDiT1(HunyuanDiT):
|
||||
unet_config = {
|
||||
"image_model": "hydit1",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start" : 0.00085,
|
||||
"linear_end" : 0.03,
|
||||
}
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
from comfy import sd1_clip
|
||||
from transformers import T5TokenizerFast, BertTokenizer, BertModel, modeling_utils, BertConfig
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
import torch
|
||||
import contextlib
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_comfy_ops(ops, device=None, dtype=None):
|
||||
old_torch_nn_linear = torch.nn.Linear
|
||||
force_device = device
|
||||
force_dtype = dtype
|
||||
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
if force_device is not None:
|
||||
device = force_device
|
||||
if force_dtype is not None:
|
||||
dtype = force_dtype
|
||||
return ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
torch.nn.Linear = linear_with_dtype
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.nn.Linear = old_torch_nn_linear
|
||||
|
||||
|
||||
class RobertaWrapper(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = BertConfig(**config_dict)
|
||||
with use_comfy_ops(operations, device, dtype):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.bert = BertModel(config, add_pooling_layer=False)
|
||||
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.bert.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
return self.bert.set_input_embeddings(value)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
intermediate = None
|
||||
out = self.bert(input_ids=input_tokens, output_hidden_states=intermediate_output is not None, attention_mask=attention_mask)
|
||||
return out.last_hidden_state, intermediate, out.pooler_output
|
||||
|
||||
class HyditBertModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=RobertaWrapper, enable_attention_masks=True, return_attention_masks=True)
|
||||
|
||||
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer)
|
||||
|
||||
|
||||
class MT5XLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
|
||||
|
||||
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class HyditTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
|
||||
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
|
||||
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
|
||||
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.hydit_clip.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
||||
|
||||
class HyditModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.hydit_clip = HyditBertModel()
|
||||
self.mt5xl = MT5XLModel()
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
self.dtypes.add(dtype)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
hydit_out = self.hydit_clip.encode_token_weights(token_weight_pairs["hydit_clip"])
|
||||
mt5_out = self.mt5xl.encode_token_weights(token_weight_pairs["mt5xl"])
|
||||
return hydit_out[0], hydit_out[1], {"attention_mask": hydit_out[2]["attention_mask"], "conditioning_mt5xl": mt5_out[0], "attention_mask_mt5xl": mt5_out[2]["attention_mask"]}
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "bert.encoder.layer.0.attention.self.query.weight" in sd:
|
||||
return self.hydit_clip.load_sd(sd)
|
||||
else:
|
||||
return self.mt5xl.load_sd(sd)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.hydit_clip.set_clip_options(options)
|
||||
self.mt5xl.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.hydit_clip.reset_clip_options()
|
||||
self.mt5xl.reset_clip_options()
|
|
@ -0,0 +1,35 @@
|
|||
{
|
||||
"_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
|
||||
"architectures": [
|
||||
"BertModel"
|
||||
],
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"bos_token_id": 0,
|
||||
"classifier_dropout": null,
|
||||
"directionality": "bidi",
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-12,
|
||||
"max_position_embeddings": 512,
|
||||
"model_type": "bert",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"pooler_fc_size": 768,
|
||||
"pooler_num_attention_heads": 12,
|
||||
"pooler_num_fc_layers": 3,
|
||||
"pooler_size_per_head": 128,
|
||||
"pooler_type": "first_token_transform",
|
||||
"position_embedding_type": "absolute",
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.22.1",
|
||||
"type_vocab_size": 2,
|
||||
"use_cache": true,
|
||||
"vocab_size": 47020
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"cls_token": "[CLS]",
|
||||
"mask_token": "[MASK]",
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"cls_token": "[CLS]",
|
||||
"do_basic_tokenize": true,
|
||||
"do_lower_case": true,
|
||||
"mask_token": "[MASK]",
|
||||
"name_or_path": "hfl/chinese-roberta-wwm-ext",
|
||||
"never_split": null,
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"unk_token": "[UNK]",
|
||||
"model_max_length": 77
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
"d_ff": 5120,
|
||||
"d_kv": 64,
|
||||
"d_model": 2048,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "mt5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 32,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 250112
|
||||
}
|
|
@ -27,3 +27,6 @@ class SPieceTokenizer:
|
|||
def __call__(self, string):
|
||||
out = self.tokenizer.encode(string)
|
||||
return {"input_ids": out}
|
||||
|
||||
def serialize_model(self):
|
||||
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
|
||||
|
|
Loading…
Reference in New Issue