Basic Genmo Mochi video model support.
To use: "Load CLIP" node with t5xxl + type mochi "Load Diffusion Model" node with the mochi dit file. "Load VAE" with the mochi vae file. EmptyMochiLatentVideo node for the latent. euler + linear_quadratic in the KSampler node.
This commit is contained in:
parent c3ffbae067
commit 5cbb01bc2f
@@ -175,3 +175,30 @@ class Flux(SD3):

     def process_out(self, latent):
         return (latent / self.scale_factor) + self.shift_factor
+
+class Mochi(LatentFormat):
+    latent_channels = 12
+
+    def __init__(self):
+        self.scale_factor = 1.0
+        self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
+                                          -0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
+                                          0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
+                                          -0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
+        self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
+                                         0.959253732819592, 0.8244560132752793, 0.917259975397747,
+                                         0.9294154431013696, 1.3720942357788521, 0.881393668867029,
+                                         0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
+
+        self.latent_rgb_factors = None #TODO
+        self.taesd_decoder_name = None #TODO
+
+    def process_in(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return (latent - latents_mean) * self.scale_factor / latents_std
+
+    def process_out(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return latent * latents_std / self.scale_factor + latents_mean
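For context, a ComfyUI latent format maps VAE latents to the scale the diffusion model expects and back again. A minimal round-trip sketch with the new class (illustrative only; it assumes Mochi lives in comfy.latent_formats next to the Flux/SD3 formats named in the hunk header, and uses a dummy latent shape):

import torch
from comfy.latent_formats import Mochi

fmt = Mochi()
z = torch.randn(1, 12, 2, 8, 8)               # (B, C=12, t, h, w) dummy video latent
z_in = fmt.process_in(z)                      # (z - mean) / std, fed to the diffusion model
z_back = fmt.process_out(z_in)                # z * std + mean, handed back to the VAE
print(torch.allclose(z, z_back, atol=1e-5))   # True: the two methods are inverses (scale_factor is 1.0)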
@@ -13,9 +13,15 @@ try:
 except:
     rms_norm_torch = None

-def rms_norm(x, weight, eps=1e-6):
+def rms_norm(x, weight=None, eps=1e-6):
     if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
-        return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+        if weight is None:
+            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
+        else:
+            return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
     else:
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-        return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
+        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        if weight is None:
+            return r
+        else:
+            return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
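The weightless path added here is what the Mochi blocks below rely on: modulated_rmsnorm calls rms_norm with no weight and applies its own learned scale afterwards. A quick sketch of what the weight=None call computes (comfy.ldm.common_dit is the module named in the new code; the check holds on either the fused or the fallback path):

import torch
from comfy.ldm.common_dit import rms_norm

x = torch.randn(2, 16, 64)
out = rms_norm(x)                                                   # weight=None: pure RMS normalization, no affine scale
ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
print(torch.allclose(out, ref, atol=1e-5))                          # True: both paths compute the same normalization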
@@ -0,0 +1,541 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from flash_attn import flash_attn_varlen_qkvpacked_func
from comfy.ldm.modules.attention import optimized_attention

from .layers import (
    FeedForward,
    PatchEmbed,
    RMSNorm,
    TimestepEmbedder,
)

from .rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
    AttentionPool,
    modulate,
)

import comfy.ldm.common_dit
import comfy.ops


def modulated_rmsnorm(x, scale, eps=1e-6):
    # Normalize and modulate
    x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
    x_modulated = x_normed * (1 + scale.unsqueeze(1))

    return x_modulated


def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
    # Apply tanh to gate
    tanh_gate = torch.tanh(gate).unsqueeze(1)

    # Normalize and apply gated scaling
    x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate

    # Apply residual connection
    output = x + x_normed

    return output


class AsymmetricAttention(nn.Module):
    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        update_y: bool = True,
        out_bias: bool = True,
        attend_to_padding: bool = False,
        softmax_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.num_heads = num_heads
        self.head_dim = dim_x // num_heads
        self.attn_drop = attn_drop
        self.update_y = update_y
        self.attend_to_padding = attend_to_padding
        self.softmax_scale = softmax_scale
        if dim_x % num_heads != 0:
            raise ValueError(
                f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
            )

        # Input layers.
        self.qkv_bias = qkv_bias
        self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
        # Project text features to match visual features (dim_y -> dim_x)
        self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)

        # Query and key normalization for stability.
        assert qk_norm
        self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
        self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)

        # Output layers. y features go back down from dim_x -> dim_y.
        self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
        self.proj_y = (
            operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
            if update_y
            else nn.Identity()
        )

    def forward(
        self,
        x: torch.Tensor,  # (B, N, dim_x)
        y: torch.Tensor,  # (B, L, dim_y)
        scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
        scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
        crop_y,
        **rope_rotation,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        rope_cos = rope_rotation.get("rope_cos")
        rope_sin = rope_rotation.get("rope_sin")
        # Pre-norm for visual features
        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size

        # Process visual features
        # qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
        # assert qkv_x.dtype == torch.bfloat16
        # qkv_x = all_to_all_collect_tokens(
        #     qkv_x, self.num_heads
        # )  # (3, B, N, local_h, head_dim)

        # Process text features
        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)

        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)

        # Split qkv_x into q, k, v
        q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
        k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
        v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)

        xy = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)

        x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
        x = self.proj_x(x)
        o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
        o[:, :y.shape[1]] = y

        y = self.proj_y(o)
        # print("ox", x)
        # print("oy", y)
        return x, y


class AsymmetricJointBlock(nn.Module):
    def __init__(
        self,
        hidden_size_x: int,
        hidden_size_y: int,
        num_heads: int,
        *,
        mlp_ratio_x: float = 8.0,  # Ratio of hidden size to d_model for MLP for visual tokens.
        mlp_ratio_y: float = 4.0,  # Ratio of hidden size to d_model for MLP for text tokens.
        update_y: bool = True,  # Whether to update text tokens in this block.
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
        **block_kwargs,
    ):
        super().__init__()
        self.update_y = update_y
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
        if self.update_y:
            self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
        else:
            self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)

        # Self-attention:
        self.attn = AsymmetricAttention(
            hidden_size_x,
            hidden_size_y,
            num_heads=num_heads,
            update_y=update_y,
            device=device,
            dtype=dtype,
            operations=operations,
            **block_kwargs,
        )

        # MLP.
        mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
        assert mlp_hidden_dim_x == int(1536 * 8)
        self.mlp_x = FeedForward(
            in_features=hidden_size_x,
            hidden_size=mlp_hidden_dim_x,
            multiple_of=256,
            ffn_dim_multiplier=None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        # MLP for text not needed in last block.
        if self.update_y:
            mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
            self.mlp_y = FeedForward(
                in_features=hidden_size_y,
                hidden_size=mlp_hidden_dim_y,
                multiple_of=256,
                ffn_dim_multiplier=None,
                device=device,
                dtype=dtype,
                operations=operations,
            )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        y: torch.Tensor,
        **attn_kwargs,
    ):
        """Forward pass of a block.

        Args:
            x: (B, N, dim) tensor of visual tokens
            c: (B, dim) tensor of conditioned features
            y: (B, L, dim) tensor of text tokens
            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

        Returns:
            x: (B, N, dim) tensor of visual tokens after block
            y: (B, L, dim) tensor of text tokens after block
        """
        N = x.size(1)

        c = F.silu(c)
        mod_x = self.mod_x(c)
        scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)

        mod_y = self.mod_y(c)
        if self.update_y:
            scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
        else:
            scale_msa_y = mod_y

        # Self-attention block.
        x_attn, y_attn = self.attn(
            x,
            y,
            scale_x=scale_msa_x,
            scale_y=scale_msa_y,
            **attn_kwargs,
        )

        assert x_attn.size(1) == N
        x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
        if self.update_y:
            y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)

        # MLP block.
        x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
        if self.update_y:
            y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)

        return x, y

    def ff_block_x(self, x, scale_x, gate_x):
        x_mod = modulated_rmsnorm(x, scale_x)
        x_res = self.mlp_x(x_mod)
        x = residual_tanh_gated_rmsnorm(x, x_res, gate_x)  # Sandwich norm
        return x

    def ff_block_y(self, y, scale_y, gate_y):
        y_mod = modulated_rmsnorm(y, scale_y)
        y_res = self.mlp_y(y_mod)
        y = residual_tanh_gated_rmsnorm(y, y_res, gate_y)  # Sandwich norm
        return y


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(
        self,
        hidden_size,
        patch_size,
        out_channels,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.norm_final = operations.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
        )
        self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
        self.linear = operations.Linear(
            hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
        )

    def forward(self, x, c):
        c = F.silu(c)
        shift, scale = self.mod(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class AsymmDiTJoint(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Ingests text embeddings instead of a label.
    """

    def __init__(
        self,
        *,
        patch_size=2,
        in_channels=4,
        hidden_size_x=1152,
        hidden_size_y=1152,
        depth=48,
        num_heads=16,
        mlp_ratio_x=8.0,
        mlp_ratio_y=4.0,
        use_t5: bool = False,
        t5_feat_dim: int = 4096,
        t5_token_length: int = 256,
        learn_sigma=True,
        patch_embed_bias: bool = True,
        timestep_mlp_bias: bool = True,
        attend_to_padding: bool = False,
        timestep_scale: Optional[float] = None,
        use_extended_posenc: bool = False,
        posenc_preserve_area: bool = False,
        rope_theta: float = 10000.0,
        image_model=None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
        **block_kwargs,
    ):
        super().__init__()

        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        self.head_dim = (
            hidden_size_x // num_heads
        )  # Head dimension and count is determined by visual.
        self.attend_to_padding = attend_to_padding
        self.use_extended_posenc = use_extended_posenc
        self.posenc_preserve_area = posenc_preserve_area
        self.use_t5 = use_t5
        self.t5_token_length = t5_token_length
        self.t5_feat_dim = t5_feat_dim
        self.rope_theta = (
            rope_theta  # Scaling factor for frequency computation for temporal RoPE.
        )

        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size_x,
            bias=patch_embed_bias,
            dtype=dtype,
            device=device,
            operations=operations
        )
        # Conditionings
        # Timestep
        self.t_embedder = TimestepEmbedder(
            hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
        )

        if self.use_t5:
            # Caption Pooling (T5)
            self.t5_y_embedder = AttentionPool(
                t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
            )

            # Dense Embedding Projection (T5)
            self.t5_yproj = operations.Linear(
                t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
            )

        # Initialize pos_frequencies as an empty parameter.
        self.pos_frequencies = nn.Parameter(
            torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
        )

        assert not self.attend_to_padding

        # for depth 48:
        #  b = 0: AsymmetricJointBlock, update_y=True
        #  b = 1: AsymmetricJointBlock, update_y=True
        #  ...
        #  b = 46: AsymmetricJointBlock, update_y=True
        #  b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
        blocks = []
        for b in range(depth):
            # Joint multi-modal block
            update_y = b < depth - 1
            block = AsymmetricJointBlock(
                hidden_size_x,
                hidden_size_y,
                num_heads,
                mlp_ratio_x=mlp_ratio_x,
                mlp_ratio_y=mlp_ratio_y,
                update_y=update_y,
                attend_to_padding=attend_to_padding,
                device=device,
                dtype=dtype,
                operations=operations,
                **block_kwargs,
            )

            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        self.final_layer = FinalLayer(
            hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
        )

    def embed_x(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C=12, T, H, W) tensor of visual tokens

        Returns:
            x: (B, C=3072, N) tensor of visual tokens with positional embedding.
        """
        return self.x_embedder(x)  # Convert BcTHW to BCN

    def prepare(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        t5_feat: torch.Tensor,
        t5_mask: torch.Tensor,
    ):
        """Prepare input and conditioning embeddings."""
        # Visual patch embeddings with positional encoding.
        T, H, W = x.shape[-3:]
        pH, pW = H // self.patch_size, W // self.patch_size
        x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
        assert x.ndim == 3
        B = x.size(0)

        pH, pW = H // self.patch_size, W // self.patch_size
        N = T * pH * pW
        assert x.size(1) == N
        pos = create_position_matrix(
            T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
        )  # (N, 3)
        rope_cos, rope_sin = compute_mixed_rotation(
            freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
        )  # Each are (N, num_heads, dim // 2)

        c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype)  # (B, D)

        t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask)  # (B, D)

        c = c_t + t5_y_pool

        y_feat = self.t5_yproj(t5_feat)  # (B, L, t5_feat_dim) --> (B, L, D)

        return x, c, y_feat, rope_cos, rope_sin

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: List[torch.Tensor],
        attention_mask: List[torch.Tensor],
        num_tokens=256,
        packed_indices: Dict[str, torch.Tensor] = None,
        rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
        control=None, **kwargs
    ):
        y_feat = context
        y_mask = attention_mask
        sigma = timestep
        """Forward pass of DiT.

        Args:
            x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
            sigma: (B,) tensor of noise standard deviations
            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
            y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
        """
        B, _, T, H, W = x.shape

        x, c, y_feat, rope_cos, rope_sin = self.prepare(
            x, sigma, y_feat, y_mask
        )
        del y_mask

        for i, block in enumerate(self.blocks):
            x, y_feat = block(
                x,
                c,
                y_feat,
                rope_cos=rope_cos,
                rope_sin=rope_sin,
                crop_y=num_tokens,
            )  # (B, M, D), (B, L, D)
        del y_feat  # Final layers don't use dense text features.

        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
        x = rearrange(
            x,
            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
            T=T,
            hp=H // self.patch_size,
            wp=W // self.patch_size,
            p1=self.patch_size,
            p2=self.patch_size,
            c=self.out_channels,
        )

        return -x
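The block wiring above is easier to see with the math written out. This is a minimal stand-alone sketch of the modulate-then-gate ("sandwich norm") pattern that modulated_rmsnorm and residual_tanh_gated_rmsnorm implement, with a plain tensor operation standing in for the attention/MLP sub-layer:

import torch

def rms_norm(x, eps=1e-6):
    # same math as the weightless comfy.ldm.common_dit.rms_norm fallback
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

B, N, D = 1, 4, 8
x = torch.randn(B, N, D)       # tokens entering the block
scale = torch.randn(B, D)      # per-sample modulation, produced by mod_x(c)
gate = torch.randn(B, D)       # per-sample gate, produced by mod_x(c)

x_mod = rms_norm(x) * (1 + scale.unsqueeze(1))                 # modulated_rmsnorm
x_res = torch.tanh(x_mod)                                      # stand-in for the attention or MLP output
x_out = x + rms_norm(x_res) * torch.tanh(gate).unsqueeze(1)    # residual_tanh_gated_rmsnorm
print(x_out.shape)                                             # torch.Size([1, 4, 8])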
@@ -0,0 +1,164 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI

import collections.abc
import math
from itertools import repeat
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ldm.common_dit


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


class TimestepEmbedder(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        *,
        bias: bool = True,
        timestep_scale: Optional[float] = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.timestep_scale = timestep_scale

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs.mul_(-math.log(max_period) / half).exp_()
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t, out_dtype):
        if self.timestep_scale is not None:
            t = t * self.timestep_scale
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class FeedForward(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        # keep parameter count and computation constant compared to standard FFN
        hidden_size = int(2 * hidden_size / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_size = int(ffn_dim_multiplier * hidden_size)
        hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)

        self.hidden_dim = hidden_size
        self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
        self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        x, gate = self.w1(x).chunk(2, dim=-1)
        x = self.w2(F.silu(x) * gate)
        return x


class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = True,
        bias: bool = True,
        dynamic_img_pad: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.flatten = flatten
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = operations.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        assert norm_layer is None
        self.norm = (
            norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
        )

    def forward(self, x):
        B, _C, T, H, W = x.shape
        if not self.dynamic_img_pad:
            assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
            assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
        else:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))

        x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
        x = self.proj(x)

        # Flatten temporal and spatial dimensions.
        if not self.flatten:
            raise NotImplementedError("Must flatten output.")
        x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)

        x = self.norm(x)
        return x


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
        self.register_parameter("bias", None)

    def forward(self, x):
        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
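The FeedForward above is the SwiGLU-style gated MLP from the Llama family: w1 produces the value and gate halves in a single projection, and the 2/3 rescale plus multiple_of rounding keeps the parameter count comparable to a plain MLP of the nominal hidden size. A small sketch of the same computation in plain PyTorch (shapes only, random weights, not the trained model):

import torch
import torch.nn.functional as F

in_features, hidden = 1536, 1536 * 4
hidden = int(2 * hidden / 3)                       # 4096
hidden = 256 * ((hidden + 255) // 256)             # round up to a multiple of 256 -> 4096

w1 = torch.randn(2 * hidden, in_features) * 0.02   # fused value + gate projection
w2 = torch.randn(in_features, hidden) * 0.02       # output projection

x = torch.randn(2, 10, in_features)
value, gate = (x @ w1.t()).chunk(2, dim=-1)        # split into the two halves
y = (F.silu(value) * gate) @ w2.t()                # SwiGLU
print(y.shape)                                     # torch.Size([2, 10, 1536])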
@@ -0,0 +1,88 @@
#original code from https://github.com/genmoai/models under apache 2.0 license

# import functools
import math

import torch


def centers(start: float, stop, num, dtype=None, device=None):
    """linspace through bin centers.

    Args:
        start (float): Start of the range.
        stop (float): End of the range.
        num (int): Number of points.
        dtype (torch.dtype): Data type of the points.
        device (torch.device): Device of the points.

    Returns:
        centers (Tensor): Centers of the bins. Shape: (num,).
    """
    edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
    return (edges[:-1] + edges[1:]) / 2


# @functools.lru_cache(maxsize=1)
def create_position_matrix(
    T: int,
    pH: int,
    pW: int,
    device: torch.device,
    dtype: torch.dtype,
    *,
    target_area: float = 36864,
):
    """
    Args:
        T: int - Temporal dimension
        pH: int - Height dimension after patchify
        pW: int - Width dimension after patchify

    Returns:
        pos: [T * pH * pW, 3] - position matrix
    """
    # Create 1D tensors for each dimension
    t = torch.arange(T, dtype=dtype)

    # Positionally interpolate to area 36864.
    # (3072x3072 frame with 16x16 patches = 192x192 latents).
    # This automatically scales rope positions when the resolution changes.
    # We use a large target area so the model is more sensitive
    # to changes in the learned pos_frequencies matrix.
    scale = math.sqrt(target_area / (pW * pH))
    w = centers(-pW * scale / 2, pW * scale / 2, pW)
    h = centers(-pH * scale / 2, pH * scale / 2, pH)

    # Use meshgrid to create 3D grids
    grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")

    # Stack and reshape the grids.
    pos = torch.stack([grid_t, grid_h, grid_w], dim=-1)  # [T, pH, pW, 3]
    pos = pos.view(-1, 3)  # [T * pH * pW, 3]
    pos = pos.to(dtype=dtype, device=device)

    return pos


def compute_mixed_rotation(
    freqs: torch.Tensor,
    pos: torch.Tensor,
):
    """
    Project each 3-dim position into per-head, per-head-dim 1D frequencies.

    Args:
        freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
        pos: [N, 3] - position of each token
        num_heads: int

    Returns:
        freqs_cos: [N, num_heads, num_freqs] - cosine components
        freqs_sin: [N, num_heads, num_freqs] - sine components
    """
    assert freqs.ndim == 3
    freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
    freqs_cos = torch.cos(freqs_sum)
    freqs_sin = torch.sin(freqs_sum)
    return freqs_cos, freqs_sin
@@ -0,0 +1,34 @@
#original code from https://github.com/genmoai/models under apache 2.0 license

# Based on Llama3 Implementation.
import torch


def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.

    Args:
        xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
                            Can be either just query or just key, or both stacked along some batch or * dim.
        freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
        freqs_sin (torch.Tensor): Precomputed sine frequency tensor.

    Returns:
        torch.Tensor: The input tensor with rotary embeddings applied.
    """
    # Split the last dimension into even and odd parts
    xqk_even = xqk[..., 0::2]
    xqk_odd = xqk[..., 1::2]

    # Apply rotation
    cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
    sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)

    # Interleave the results back into the original shape
    out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
    return out
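Taken together, the two files above build and apply the learned mixed-axis RoPE: create_position_matrix lays out (t, row, col) coordinates, compute_mixed_rotation projects them through the learned pos_frequencies, and apply_rotary_emb_qk_real rotates the query/key channels. A short shape check wiring them end to end (a sketch; the absolute module paths are an assumption based on the package layout implied by the imports in asymm_models_joint.py, and the random freqs tensor stands in for the learned pos_frequencies):

import torch
from comfy.ldm.genmo.joint_model.rope_mixed import create_position_matrix, compute_mixed_rotation
from comfy.ldm.genmo.joint_model.temporal_rope import apply_rotary_emb_qk_real

T, pH, pW, heads, head_dim = 2, 4, 6, 24, 128
pos = create_position_matrix(T, pH=pH, pW=pW, device="cpu", dtype=torch.float32)  # (T*pH*pW, 3)
freqs = torch.randn(3, heads, head_dim // 2) * 0.01                               # stand-in for pos_frequencies
rope_cos, rope_sin = compute_mixed_rotation(freqs=freqs, pos=pos)                 # each (N, heads, head_dim//2)

q = torch.randn(1, T * pH * pW, heads, head_dim)                                  # (B, N, heads, head_dim)
q_rot = apply_rotary_emb_qk_real(q, rope_cos, rope_sin)
print(q_rot.shape)                                                                 # torch.Size([1, 48, 24, 128])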
@@ -0,0 +1,102 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
    """
    Pool tokens in x using mask.

    NOTE: We assume x does not require gradients.

    Args:
        x: (B, L, D) tensor of tokens.
        mask: (B, L) boolean tensor indicating which tokens are not padding.

    Returns:
        pooled: (B, D) tensor of pooled tokens.
    """
    assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
    assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
    mask = mask[:, :, None].to(dtype=x.dtype)
    mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = (x * mask).sum(dim=1, keepdim=keepdim)
    return pooled


class AttentionPool(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        output_dim: int = None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        """
        Args:
            spatial_dim (int): Number of tokens in sequence length.
            embed_dim (int): Dimensionality of input tokens.
            num_heads (int): Number of attention heads.
            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
        """
        super().__init__()
        self.num_heads = num_heads
        self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
        self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
        self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)

    def forward(self, x, mask):
        """
        Args:
            x (torch.Tensor): (B, L, D) tensor of input tokens.
            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.

            NOTE: We assume x does not require gradients.

        Returns:
            x (torch.Tensor): (B, D) tensor of pooled tokens.
        """
        D = x.size(2)

        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).

        # Average non-padding token features. These will be used as the query.
        x_pool = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)

        # Concat pooled features to input sequence.
        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)

        # Compute queries, keys, values. Only the mean token is used to create a query.
        kv = self.to_kv(x)  # (B, L+1, 2 * D)
        q = self.to_q(x[:, 0])  # (B, D)

        # Extract heads.
        head_dim = D // self.num_heads
        kv = kv.unflatten(2, (2, self.num_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
        q = q.unflatten(1, (self.num_heads, head_dim))  # (B, H, head_dim)
        q = q.unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0
        )  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
        x = self.to_out(x)
        return x
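pool_tokens is a masked mean over the token axis: padded positions get zero weight and the remaining weights sum to one. A one-line sanity check (hypothetical values; the module path is the same assumption as above):

import torch
from comfy.ldm.genmo.joint_model.utils import pool_tokens

x = torch.tensor([[[1.0], [3.0], [100.0]]])   # (B=1, L=3, D=1)
mask = torch.tensor([[True, True, False]])     # last token is padding
print(pool_tokens(x, mask))                    # tensor([[2.]]) -- mean of the two real tokens only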
@@ -0,0 +1,480 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI

from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

import comfy.ops
ops = comfy.ops.disable_weight_init

# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


class GroupNormSpatial(ops.GroupNorm):
    """
    GroupNorm applied per-frame.
    """

    def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
        B, C, T, H, W = x.shape
        x = rearrange(x, "B C T H W -> (B T) C H W")
        # Run group norm in chunks.
        output = torch.empty_like(x)
        for b in range(0, B * T, chunk_size):
            output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
        return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)


class PConv3d(ops.Conv3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        causal: bool = True,
        context_parallel: bool = True,
        **kwargs,
    ):
        self.causal = causal
        self.context_parallel = context_parallel
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        height_pad = (kernel_size[1] - 1) // 2
        width_pad = (kernel_size[2] - 1) // 2

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, height_pad, width_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        # Compute padding amounts.
        context_size = self.kernel_size[0] - 1
        if self.causal:
            pad_front = context_size
            pad_back = 0
        else:
            pad_front = context_size // 2
            pad_back = context_size - pad_front

        # Apply padding.
        assert self.padding_mode == "replicate"  # DEBUG
        mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
        x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
        return super().forward(x)


class Conv1x1(ops.Linear):
    """*1x1 Conv implemented with a linear layer."""

    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
        super().__init__(in_features, out_features, *args, **kwargs)

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, *] or [B, *, C].

        Returns:
            x: Output tensor. Shape: [B, C', *] or [B, *, C'].
        """
        x = x.movedim(1, -1)
        x = super().forward(x)
        x = x.movedim(-1, 1)
        return x


class DepthToSpaceTime(nn.Module):
    def __init__(
        self,
        temporal_expansion: int,
        spatial_expansion: int,
    ):
        super().__init__()
        self.temporal_expansion = temporal_expansion
        self.spatial_expansion = spatial_expansion

    # When printed, this module should show the temporal and spatial expansion factors.
    def extra_repr(self):
        return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].

        Returns:
            x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
        """
        x = rearrange(
            x,
            "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
            st=self.temporal_expansion,
            sh=self.spatial_expansion,
            sw=self.spatial_expansion,
        )

        # cp_rank, _ = cp.get_cp_rank_size()
        if self.temporal_expansion > 1:  # and cp_rank == 0:
            # Drop the first self.temporal_expansion - 1 frames.
            # This is because we always want the 3x3x3 conv filter to only apply
            # to the first frame, and the first frame doesn't need to be repeated.
            assert all(x.shape)
            x = x[:, :, self.temporal_expansion - 1 :]
            assert all(x.shape)

        return x


def norm_fn(
    in_channels: int,
    affine: bool = True,
):
    return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)


class ResBlock(nn.Module):
    """Residual block that preserves the spatial dimensions."""

    def __init__(
        self,
        channels: int,
        *,
        affine: bool = True,
        attn_block: Optional[nn.Module] = None,
        padding_mode: str = "replicate",
        causal: bool = True,
    ):
        super().__init__()
        self.channels = channels

        assert causal
        self.stack = nn.Sequential(
            norm_fn(channels, affine=affine),
            nn.SiLU(inplace=True),
            PConv3d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding_mode=padding_mode,
                bias=True,
                # causal=causal,
            ),
            norm_fn(channels, affine=affine),
            nn.SiLU(inplace=True),
            PConv3d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding_mode=padding_mode,
                bias=True,
                # causal=causal,
            ),
        )

        self.attn_block = attn_block if attn_block else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].
        """
        residual = x
        x = self.stack(x)
        x = x + residual
        del residual

        return self.attn_block(x)


class CausalUpsampleBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks: int,
        *,
        temporal_expansion: int = 2,
        spatial_expansion: int = 2,
        **block_kwargs,
    ):
        super().__init__()

        blocks = []
        for _ in range(num_res_blocks):
            blocks.append(block_fn(in_channels, **block_kwargs))
        self.blocks = nn.Sequential(*blocks)

        self.temporal_expansion = temporal_expansion
        self.spatial_expansion = spatial_expansion

        # Change channels in the final convolution layer.
        self.proj = Conv1x1(
            in_channels,
            out_channels * temporal_expansion * (spatial_expansion**2),
        )

        self.d2st = DepthToSpaceTime(
            temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
        )

    def forward(self, x):
        x = self.blocks(x)
        x = self.proj(x)
        x = self.d2st(x)
        return x


def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
    assert has_attention is False  # NOTE: if this is ever true add back the attention code.

    attn_block = None  # AttentionBlock(channels) if has_attention else None

    return ResBlock(
        channels, affine=True, attn_block=attn_block, **block_kwargs
    )


class DownsampleBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks,
        *,
        temporal_reduction=2,
        spatial_reduction=2,
        **block_kwargs,
    ):
        """
        Downsample block for the VAE encoder.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            num_res_blocks: Number of residual blocks.
            temporal_reduction: Temporal reduction factor.
            spatial_reduction: Spatial reduction factor.
        """
        super().__init__()
        layers = []

        # Change the channel count in the strided convolution.
        # This lets the ResBlock have uniform channel count,
        # as in ConvNeXt.
        assert in_channels != out_channels
        layers.append(
            PConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
                stride=(temporal_reduction, spatial_reduction, spatial_reduction),
                padding_mode="replicate",
                bias=True,
            )
        )

        for _ in range(num_res_blocks):
            layers.append(block_fn(out_channels, **block_kwargs))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
    num_freqs = (stop - start) // step
    assert inputs.ndim == 5
    C = inputs.size(1)

    # Create Base 2 Fourier features.
    freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
    assert num_freqs == len(freqs)
    w = torch.pow(2.0, freqs) * (2 * torch.pi)  # [num_freqs]
    C = inputs.shape[1]
    w = w.repeat(C)[None, :, None, None, None]  # [1, C * num_freqs, 1, 1, 1]

    # Interleaved repeat of input channels to match w.
    h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
    # Scale channels by frequency.
    h = w * h

    return torch.cat(
        [
            inputs,
            torch.sin(h),
            torch.cos(h),
        ],
        dim=1,
    )


class FourierFeatures(nn.Module):
    def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
        super().__init__()
        self.start = start
        self.stop = stop
        self.step = step

    def forward(self, inputs):
        """Add Fourier features to inputs.

        Args:
            inputs: Input tensor. Shape: [B, C, T, H, W]

        Returns:
            h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
        """
        return add_fourier_features(inputs, self.start, self.stop, self.step)


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        out_channels: int = 3,
        latent_dim: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        temporal_expansions: Optional[List[int]] = None,
        spatial_expansions: Optional[List[int]] = None,
        has_attention: List[bool],
        output_norm: bool = True,
        nonlinearity: str = "silu",
        output_nonlinearity: str = "silu",
        causal: bool = True,
        **block_kwargs,
    ):
        super().__init__()
        self.input_channels = latent_dim
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.output_nonlinearity = output_nonlinearity
        assert nonlinearity == "silu"
        assert causal

        ch = [mult * base_channels for mult in channel_multipliers]
        self.num_up_blocks = len(ch) - 1
        assert len(num_res_blocks) == self.num_up_blocks + 2

        blocks = []

        first_block = [
            nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
        ]  # Input layer.
        # First set of blocks preserve channel count.
        for _ in range(num_res_blocks[-1]):
            first_block.append(
                block_fn(
                    ch[-1],
                    has_attention=has_attention[-1],
                    causal=causal,
                    **block_kwargs,
                )
            )
        blocks.append(nn.Sequential(*first_block))

        assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
        assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2

        upsample_block_fn = CausalUpsampleBlock

        for i in range(self.num_up_blocks):
            block = upsample_block_fn(
                ch[-i - 1],
                ch[-i - 2],
                num_res_blocks=num_res_blocks[-i - 2],
                has_attention=has_attention[-i - 2],
                temporal_expansion=temporal_expansions[-i - 1],
                spatial_expansion=spatial_expansions[-i - 1],
                causal=causal,
                **block_kwargs,
            )
            blocks.append(block)

        assert not output_norm

        # Last block. Preserve channel count.
        last_block = []
        for _ in range(num_res_blocks[0]):
            last_block.append(
                block_fn(
                    ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
                )
            )
        blocks.append(nn.Sequential(*last_block))

        self.blocks = nn.ModuleList(blocks)
        self.output_proj = Conv1x1(ch[0], out_channels)

    def forward(self, x):
        """Forward pass.

        Args:
            x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].

        Returns:
            x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
               T + 1 = (t - 1) * 4.
               H = h * 16, W = w * 16.
        """
        for block in self.blocks:
            x = block(x)

        if self.output_nonlinearity == "silu":
            x = F.silu(x, inplace=not self.training)
        else:
            assert (
                not self.output_nonlinearity
            )  # StyleGAN3 omits the to-RGB nonlinearity.

        return self.output_proj(x).contiguous()


class VideoVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = None  # TODO once the model releases
        self.decoder = Decoder(
            out_channels=3,
            base_channels=128,
            channel_multipliers=[1, 2, 4, 6],
            temporal_expansions=[1, 2, 3],
            spatial_expansions=[2, 2, 2],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            has_attention=[False, False, False, False, False],
            padding_mode="replicate",
            output_norm=False,
            nonlinearity="silu",
            output_nonlinearity="silu",
            causal=True,
        )

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.decoder(x)
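DepthToSpaceTime is the pixel-shuffle step of each upsample block: channels are folded into extra time/height/width positions, then the leading temporal_expansion - 1 frames are dropped so the causal first frame is not duplicated. A quick shape check of the same operation, built directly from the rearrange pattern above rather than the module itself (a sketch with made-up sizes):

import torch
from einops import rearrange

B, C, T, H, W = 1, 4, 3, 5, 5
st, s = 2, 2                                       # temporal and spatial expansion factors
x = torch.randn(B, C * st * s * s, T, H, W)        # (1, 32, 3, 5, 5)
y = rearrange(x, "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)", st=st, sh=s, sw=s)
y = y[:, :, st - 1:]                               # drop the repeated leading frame(s)
print(y.shape)                                     # torch.Size([1, 4, 5, 10, 10])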
@@ -97,7 +97,7 @@ class PatchEmbed(nn.Module):
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

     def forward(self, x):
-        B, C, H, W = x.shape
+        # B, C, H, W = x.shape
         # if self.img_size is not None:
         #     if self.strict_img_size:
         #         _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
@@ -24,6 +24,7 @@ from comfy.ldm.cascade.stage_b import StageB
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
+import comfy.ldm.genmo.joint_model.asymm_models_joint
 import comfy.ldm.aura.mmdit
 import comfy.ldm.hydit.models
 import comfy.ldm.audio.dit
@@ -718,3 +719,18 @@ class Flux(BaseModel):
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
         return out
+
+class GenmoMochi(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+            out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        return out
@@ -145,6 +145,34 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config

+    if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
+        dit_config = {}
+        dit_config["image_model"] = "mochi_preview"
+        dit_config["depth"] = 48
+        dit_config["patch_size"] = 2
+        dit_config["num_heads"] = 24
+        dit_config["hidden_size_x"] = 3072
+        dit_config["hidden_size_y"] = 1536
+        dit_config["mlp_ratio_x"] = 4.0
+        dit_config["mlp_ratio_y"] = 4.0
+        dit_config["learn_sigma"] = False
+        dit_config["in_channels"] = 12
+        dit_config["qk_norm"] = True
+        dit_config["qkv_bias"] = False
+        dit_config["out_bias"] = True
+        dit_config["attn_drop"] = 0.0
+        dit_config["patch_embed_bias"] = True
+        dit_config["posenc_preserve_area"] = True
+        dit_config["timestep_mlp_bias"] = True
+        dit_config["attend_to_padding"] = False
+        dit_config["timestep_scale"] = 1000.0
+        dit_config["use_t5"] = True
+        dit_config["t5_feat_dim"] = 4096
+        dit_config["t5_token_length"] = 256
+        dit_config["rope_theta"] = 10000.0
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None
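To make the detection branch above concrete, here is a hypothetical sketch (not part of the commit) of the single check it performs; the key name is taken from the hunk above, while the prefix and stand-in state dict are invented for illustration.

# Hypothetical example: a Mochi checkpoint is recognized purely by the presence of
# its t5_yproj projection weight under the diffusion model prefix.
key_prefix = "model.diffusion_model."                   # assumed prefix for this illustration
state_dict_keys = {key_prefix + "t5_yproj.weight"}      # minimal stand-in for a real state dict
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys:
    print("mochi_preview config selected")              # the static dit_config above is returned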
@@ -366,6 +366,27 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
     sigs += [0.0]
     return torch.FloatTensor(sigs)

+# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
+def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
+    if steps == 1:
+        sigma_schedule = [1.0, 0.0]
+    else:
+        if linear_steps is None:
+            linear_steps = steps // 2
+        linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
+        threshold_noise_step_diff = linear_steps - threshold_noise * steps
+        quadratic_steps = steps - linear_steps
+        quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
+        linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
+        const = quadratic_coef * (linear_steps ** 2)
+        quadratic_sigma_schedule = [
+            quadratic_coef * (i ** 2) + linear_coef * i + const
+            for i in range(linear_steps, steps)
+        ]
+        sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
+        sigma_schedule = [1.0 - x for x in sigma_schedule]
+    return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
+
 def get_mask_aabb(masks):
     if masks.numel() == 0:
         return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
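As a sanity check on the schedule added above, a small worked example (not part of the commit): with 4 steps, the default threshold_noise of 0.025, and a model whose sigma_max is 1.0, the noise level ramps linearly for the first half of the steps and quadratically for the second, so the resulting sigmas stay near 1.0 early and drop steeply at the end. Values below may show tiny float rounding when actually printed.

# Worked example of the linear_quadratic schedule math above (steps=4, threshold_noise=0.025).
steps, threshold_noise = 4, 0.025
linear_steps = steps // 2                                                    # 2
linear = [i * threshold_noise / linear_steps for i in range(linear_steps)]  # [0.0, 0.0125]
diff = linear_steps - threshold_noise * steps                               # 1.9
quad_steps = steps - linear_steps                                           # 2
qc = diff / (linear_steps * quad_steps ** 2)                                # 0.2375
lc = threshold_noise / linear_steps - 2 * diff / (quad_steps ** 2)          # -0.9375
const = qc * (linear_steps ** 2)                                            # 0.95
quad = [qc * i * i + lc * i + const for i in range(linear_steps, steps)]    # [0.025, 0.275]
sigmas = [1.0 - x for x in linear + quad + [1.0]]
print(sigmas)  # [1.0, 0.9875, 0.975, 0.725, 0.0] before scaling by sigma_max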
@@ -732,7 +753,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)


-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
+SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

 def calculate_sigmas(model_sampling, scheduler_name, steps):

@@ -750,6 +771,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
         sigmas = normal_scheduler(model_sampling, steps, sgm=True)
     elif scheduler_name == "beta":
         sigmas = beta_scheduler(model_sampling, steps)
+    elif scheduler_name == "linear_quadratic":
+        sigmas = linear_quadratic_schedule(model_sampling, steps)
     else:
         logging.error("error invalid scheduler {}".format(scheduler_name))
     return sigmas
comfy/sd.py
@@ -7,6 +7,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.cascade.stage_a import StageA
 from .ldm.cascade.stage_c_coder import StageC_coder
 from .ldm.audio.autoencoder import AudioOobleckVAE
+import comfy.ldm.genmo.vae.model
 import yaml

 import comfy.utils

@@ -25,6 +26,7 @@ import comfy.text_encoders.aura_t5
 import comfy.text_encoders.hydit
 import comfy.text_encoders.flux
 import comfy.text_encoders.long_clipl
+import comfy.text_encoders.genmo

 import comfy.model_patcher
 import comfy.lora
@@ -241,6 +243,13 @@ class VAE:
             self.process_output = lambda audio: audio
             self.process_input = lambda audio: audio
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae
+            if "blocks.2.blocks.3.stack.5.weight" in sd:
+                sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
+            self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
+            self.latent_channels = 12
+            self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
             self.first_stage_model = None
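A back-of-the-envelope illustration (not part of the commit) of what the memory_used_decode estimate above requests for a typical Mochi latent; the latent shape and fp16 assumption are chosen for the example.

# Hypothetical numbers: a [1, 12, 5, 60, 106] latent (848x480, 25 frames), fp16 decode.
shape = (1, 12, 5, 60, 106)
dtype_size = 2  # bytes per element for fp16
memory_used = 1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8) * dtype_size
print(round(memory_used / (1024 ** 3), 1))  # ~22.7 GiB requested; if unavailable, decode falls back to tiling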
@@ -296,6 +305,10 @@ class VAE:
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)

+    def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
+        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
+
     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
         steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@@ -314,6 +327,7 @@ class VAE:
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)

     def decode(self, samples_in):
+        pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used)

@@ -321,16 +335,21 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
+                out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
+                if pixel_samples is None:
+                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
+                pixel_samples[x:x+batch_number] = out
         except model_management.OOM_EXCEPTION as e:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            if len(samples_in.shape) == 3:
+            dims = samples_in.ndim - 2
+            if dims == 1:
                 pixel_samples = self.decode_tiled_1d(samples_in)
-            else:
+            elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
+            elif dims == 3:
+                pixel_samples = self.decode_tiled_3d(samples_in)

         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -398,6 +417,7 @@ class CLIPType(Enum):
     STABLE_AUDIO = 4
     HUNYUAN_DIT = 5
     FLUX = 6
+    MOCHI = 7

 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []

@@ -474,8 +494,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
             clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
         elif te_model == TEModel.T5_XXL:
-            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
-            clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+            if clip_type == CLIPType.SD3:
+                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+            else: #CLIPType.MOCHI
+                clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
         elif te_model == TEModel.T5_XL:
             clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
             clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
@@ -10,6 +10,7 @@ import comfy.text_encoders.sa_t5
 import comfy.text_encoders.aura_t5
 import comfy.text_encoders.hydit
 import comfy.text_encoders.flux
+import comfy.text_encoders.genmo

 from . import supported_models_base
 from . import latent_formats

@@ -670,7 +671,36 @@ class FluxSchnell(Flux):
         out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
         return out

+class GenmoMochi(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "mochi_preview",
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 1.0,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Mochi
+
+    memory_usage_factor = 2.0 #TODO
+
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.GenmoMochi(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
+
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]

 models += [SVD_img2vid]
@@ -0,0 +1,38 @@
from comfy import sd1_clip
import comfy.text_encoders.sd3_clip
import os
from transformers import T5TokenizerFast


class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    def __init__(self, **kwargs):
        kwargs["attention_mask"] = True
        super().__init__(**kwargs)


class MochiT5XXL(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options)


class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)


class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    class MochiTEModel_(MochiT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            if dtype is None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return MochiTEModel_
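A hypothetical usage sketch (not part of the commit) of the mochi_te factory above: it returns a class rather than an instance, so the detected T5 settings stay captured while the loader constructs the model later with the runtime device and dtype. The concrete arguments below are invented for illustration.

# Hypothetical sketch only; in practice the text encoder loader performs these steps.
import torch

MochiTE = mochi_te(dtype_t5=torch.float16, t5xxl_scaled_fp8=None)  # class with the T5 settings baked in
te = MochiTE(device="cpu", model_options={})                       # instantiated later by the loader
tokenizer = MochiT5Tokenizer()                                     # pads/truncates prompts to at least 256 tokens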
@@ -731,7 +731,27 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
 @torch.inference_mode()
 def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
     dims = len(tile)
-    output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
+
+    if not (isinstance(upscale_amount, (tuple, list))):
+        upscale_amount = [upscale_amount] * dims
+
+    if not (isinstance(overlap, (tuple, list))):
+        overlap = [overlap] * dims
+
+    def get_upscale(dim, val):
+        up = upscale_amount[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return up * val
+
+    def mult_list_upscale(a):
+        out = []
+        for i in range(len(a)):
+            out.append(round(get_upscale(i, a[i])))
+        return out
+
+    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)

     for b in range(samples.shape[0]):
         s = samples[b:b+1]

@@ -743,27 +763,27 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
                 pbar.update(1)
             continue

-        out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
-        out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
+        out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
+        out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)

-        positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
+        positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]

         for it in itertools.product(*positions):
             s_in = s
             upscaled = []

             for d in range(dims):
-                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
+                pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
                 l = min(tile[d], s.shape[d + 2] - pos)
                 s_in = s_in.narrow(d + 2, pos, l)
-                upscaled.append(round(pos * upscale_amount))
+                upscaled.append(round(get_upscale(d, pos)))

             ps = function(s_in).to(output_device)
             mask = torch.ones_like(ps)
-            feather = round(overlap * upscale_amount)
-            for t in range(feather):
-                for d in range(2, dims + 2):
+            for d in range(2, dims + 2):
+                feather = round(get_upscale(d - 2, overlap[d - 2]))
+                for t in range(feather):
                     a = (t + 1) / feather
                     mask.narrow(d, t, 1).mul_(a)
                     mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
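A toy illustration (not part of the commit) of how the per-dimension overlap introduced above changes the tile positions when decode_tiled_3d calls tiled_scale_multidim on a 3D latent; the latent shape below is chosen for the example.

# Illustrative numbers only: tile=(999, 32, 32) and overlap=(1, 8, 8) over a
# (t, h, w) = (5, 60, 106) latent, mirroring the positions computation above.
shape_tail = (5, 60, 106)
tile = (999, 32, 32)
overlap = (1, 8, 8)
positions = [list(range(0, shape_tail[d], tile[d] - overlap[d])) if shape_tail[d] > tile[d] else [0]
             for d in range(3)]
print(positions)  # [[0], [0, 24, 48], [0, 24, 48, 72, 96]] -> 1 x 3 x 5 tiles to decode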
@@ -0,0 +1,26 @@
import nodes
import torch
import comfy.model_management

class EmptyMochiLatentVideo:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/mochi"

    def generate(self, width, height, length, batch_size=1):
        latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
        return ({"samples":latent}, )

NODE_CLASS_MAPPINGS = {
    "EmptyMochiLatentVideo": EmptyMochiLatentVideo,
}
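A small worked example (not part of the commit) of the latent that generate() above allocates for the node's default inputs:

# Node defaults: width=848, height=480, length=25, batch_size=1.
length, height, width, batch_size = 25, 480, 848, 1
latent_shape = [batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8]
print(latent_shape)  # [1, 12, 5, 60, 106]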
nodes.py
@@ -281,7 +281,10 @@ class VAEDecode:
     DESCRIPTION = "Decodes latent images back into pixel space images."

     def decode(self, vae, samples):
-        return (vae.decode(samples["samples"]), )
+        images = vae.decode(samples["samples"])
+        if len(images.shape) == 5: #Combine batches
+            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
+        return (images, )

 class VAEDecodeTiled:
     @classmethod
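To illustrate the 5D branch added to VAEDecode above (not part of the commit): a video decode comes back as [B, T, H, W, C] after the VAE's channel move, and the reshape folds batch and time into one image batch so downstream image nodes keep working. The shapes below are chosen for the example.

# Illustrative shapes only.
import torch
images = torch.zeros(2, 25, 480, 848, 3)   # [B, T, H, W, C] from a video VAE decode
if len(images.shape) == 5:
    images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
print(images.shape)  # torch.Size([50, 480, 848, 3]) -> 50 individual frames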
@@ -886,7 +889,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
                              }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

@@ -900,6 +903,8 @@ class CLIPLoader:
             clip_type = comfy.sd.CLIPType.SD3
         elif type == "stable_audio":
             clip_type = comfy.sd.CLIPType.STABLE_AUDIO
+        elif type == "mochi":
+            clip_type = comfy.sd.CLIPType.MOCHI
         else:
             clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
@@ -2111,6 +2116,7 @@ def init_builtin_extra_nodes():
         "nodes_flux.py",
         "nodes_lora_extract.py",
         "nodes_torch_compile.py",
+        "nodes_mochi.py",
     ]

     import_failed = []