diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 191c7091..a48f60c7 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -175,3 +175,30 @@ class Flux(SD3): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor + +class Mochi(LatentFormat): + latent_channels = 12 + + def __init__(self): + self.scale_factor = 1.0 + self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, + -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, + 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, + -0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1) + self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617, + 0.959253732819592, 0.8244560132752793, 0.917259975397747, + 0.9294154431013696, 1.3720942357788521, 0.881393668867029, + 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1) + + self.latent_rgb_factors = None #TODO + self.taesd_decoder_name = None #TODO + + def process_in(self, latent): + latents_mean = self.latents_mean.to(latent.device, latent.dtype) + latents_std = self.latents_std.to(latent.device, latent.dtype) + return (latent - latents_mean) * self.scale_factor / latents_std + + def process_out(self, latent): + latents_mean = self.latents_mean.to(latent.device, latent.dtype) + latents_std = self.latents_std.to(latent.device, latent.dtype) + return latent * latents_std / self.scale_factor + latents_mean diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index 5aebaf9e..cb6b7414 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -13,9 +13,15 @@ try: except: rms_norm_torch = None -def rms_norm(x, weight, eps=1e-6): +def rms_norm(x, weight=None, eps=1e-6): if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): - return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) + if weight is None: + return rms_norm_torch(x, (x.shape[-1],), eps=eps) + else: + return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) else: - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) - return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device) + r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) + if weight is None: + return r + else: + return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device) diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py new file mode 100644 index 00000000..c36a0006 --- /dev/null +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -0,0 +1,541 @@ +#original code from https://github.com/genmoai/models under apache 2.0 license +#adapted to ComfyUI + +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +# from flash_attn import flash_attn_varlen_qkvpacked_func +from comfy.ldm.modules.attention import optimized_attention + +from .layers import ( + FeedForward, + PatchEmbed, + RMSNorm, + TimestepEmbedder, +) + +from .rope_mixed import ( + compute_mixed_rotation, + create_position_matrix, +) +from .temporal_rope import apply_rotary_emb_qk_real +from .utils import ( + AttentionPool, + modulate, +) + +import comfy.ldm.common_dit 
+import comfy.ops + + +def modulated_rmsnorm(x, scale, eps=1e-6): + # Normalize and modulate + x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps) + x_modulated = x_normed * (1 + scale.unsqueeze(1)) + + return x_modulated + + +def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6): + # Apply tanh to gate + tanh_gate = torch.tanh(gate).unsqueeze(1) + + # Normalize and apply gated scaling + x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate + + # Apply residual connection + output = x + x_normed + + return output + +class AsymmetricAttention(nn.Module): + def __init__( + self, + dim_x: int, + dim_y: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0.0, + update_y: bool = True, + out_bias: bool = True, + attend_to_padding: bool = False, + softmax_scale: Optional[float] = None, + device: Optional[torch.device] = None, + dtype=None, + operations=None, + ): + super().__init__() + self.dim_x = dim_x + self.dim_y = dim_y + self.num_heads = num_heads + self.head_dim = dim_x // num_heads + self.attn_drop = attn_drop + self.update_y = update_y + self.attend_to_padding = attend_to_padding + self.softmax_scale = softmax_scale + if dim_x % num_heads != 0: + raise ValueError( + f"dim_x={dim_x} should be divisible by num_heads={num_heads}" + ) + + # Input layers. + self.qkv_bias = qkv_bias + self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype) + # Project text features to match visual features (dim_y -> dim_x) + self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype) + + # Query and key normalization for stability. + assert qk_norm + self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype) + self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype) + self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype) + self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype) + + # Output layers. y features go back down from dim_x -> dim_y. + self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype) + self.proj_y = ( + operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype) + if update_y + else nn.Identity() + ) + + def forward( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. + scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. 
+ crop_y, + **rope_rotation, + ) -> Tuple[torch.Tensor, torch.Tensor]: + rope_cos = rope_rotation.get("rope_cos") + rope_sin = rope_rotation.get("rope_sin") + # Pre-norm for visual features + x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size + + # Process visual features + # qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) + # assert qkv_x.dtype == torch.bfloat16 + # qkv_x = all_to_all_collect_tokens( + # qkv_x, self.num_heads + # ) # (3, B, N, local_h, head_dim) + + # Process text features + y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) + q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim) + + q_y = self.q_norm_y(q_y) + k_y = self.k_norm_y(k_y) + + # Split qkv_x into q, k, v + q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim) + q_x = self.q_norm_x(q_x) + q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) + k_x = self.k_norm_x(k_x) + k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) + + q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2) + k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2) + v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2) + + xy = optimized_attention(q, + k, + v, self.num_heads, skip_reshape=True) + + x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1) + x = self.proj_x(x) + o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype) + o[:, :y.shape[1]] = y + + y = self.proj_y(o) + # print("ox", x) + # print("oy", y) + return x, y + + +class AsymmetricJointBlock(nn.Module): + def __init__( + self, + hidden_size_x: int, + hidden_size_y: int, + num_heads: int, + *, + mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens. + mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens. + update_y: bool = True, # Whether to update text tokens in this block. + device: Optional[torch.device] = None, + dtype=None, + operations=None, + **block_kwargs, + ): + super().__init__() + self.update_y = update_y + self.hidden_size_x = hidden_size_x + self.hidden_size_y = hidden_size_y + self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype) + if self.update_y: + self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype) + else: + self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype) + + # Self-attention: + self.attn = AsymmetricAttention( + hidden_size_x, + hidden_size_y, + num_heads=num_heads, + update_y=update_y, + device=device, + dtype=dtype, + operations=operations, + **block_kwargs, + ) + + # MLP. + mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x) + assert mlp_hidden_dim_x == int(1536 * 8) + self.mlp_x = FeedForward( + in_features=hidden_size_x, + hidden_size=mlp_hidden_dim_x, + multiple_of=256, + ffn_dim_multiplier=None, + device=device, + dtype=dtype, + operations=operations, + ) + + # MLP for text not needed in last block. + if self.update_y: + mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y) + self.mlp_y = FeedForward( + in_features=hidden_size_y, + hidden_size=mlp_hidden_dim_y, + multiple_of=256, + ffn_dim_multiplier=None, + device=device, + dtype=dtype, + operations=operations, + ) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, + y: torch.Tensor, + **attn_kwargs, + ): + """Forward pass of a block. 
+ + Args: + x: (B, N, dim) tensor of visual tokens + c: (B, dim) tensor of conditioned features + y: (B, L, dim) tensor of text tokens + num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + + Returns: + x: (B, N, dim) tensor of visual tokens after block + y: (B, L, dim) tensor of text tokens after block + """ + N = x.size(1) + + c = F.silu(c) + mod_x = self.mod_x(c) + scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1) + + mod_y = self.mod_y(c) + if self.update_y: + scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1) + else: + scale_msa_y = mod_y + + # Self-attention block. + x_attn, y_attn = self.attn( + x, + y, + scale_x=scale_msa_x, + scale_y=scale_msa_y, + **attn_kwargs, + ) + + assert x_attn.size(1) == N + x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x) + if self.update_y: + y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y) + + # MLP block. + x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x) + if self.update_y: + y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y) + + return x, y + + def ff_block_x(self, x, scale_x, gate_x): + x_mod = modulated_rmsnorm(x, scale_x) + x_res = self.mlp_x(x_mod) + x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm + return x + + def ff_block_y(self, y, scale_y, gate_y): + y_mod = modulated_rmsnorm(y, scale_y) + y_res = self.mlp_y(y_mod) + y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm + return y + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__( + self, + hidden_size, + patch_size, + out_channels, + device: Optional[torch.device] = None, + dtype=None, + operations=None, + ): + super().__init__() + self.norm_final = operations.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype + ) + self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype) + self.linear = operations.Linear( + hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype + ) + + def forward(self, x, c): + c = F.silu(c) + shift, scale = self.mod(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class AsymmDiTJoint(nn.Module): + """ + Diffusion model with a Transformer backbone. + + Ingests text embeddings instead of a label. + """ + + def __init__( + self, + *, + patch_size=2, + in_channels=4, + hidden_size_x=1152, + hidden_size_y=1152, + depth=48, + num_heads=16, + mlp_ratio_x=8.0, + mlp_ratio_y=4.0, + use_t5: bool = False, + t5_feat_dim: int = 4096, + t5_token_length: int = 256, + learn_sigma=True, + patch_embed_bias: bool = True, + timestep_mlp_bias: bool = True, + attend_to_padding: bool = False, + timestep_scale: Optional[float] = None, + use_extended_posenc: bool = False, + posenc_preserve_area: bool = False, + rope_theta: float = 10000.0, + image_model=None, + device: Optional[torch.device] = None, + dtype=None, + operations=None, + **block_kwargs, + ): + super().__init__() + + self.dtype = dtype + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size_x = hidden_size_x + self.hidden_size_y = hidden_size_y + self.head_dim = ( + hidden_size_x // num_heads + ) # Head dimension and count is determined by visual. 
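+        # With the mochi_preview config (hidden_size_x=3072, num_heads=24) this is 128;
+        # qkv_y projects text features up to dim_x, so both streams share the same head_dim.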
+ self.attend_to_padding = attend_to_padding + self.use_extended_posenc = use_extended_posenc + self.posenc_preserve_area = posenc_preserve_area + self.use_t5 = use_t5 + self.t5_token_length = t5_token_length + self.t5_feat_dim = t5_feat_dim + self.rope_theta = ( + rope_theta # Scaling factor for frequency computation for temporal RoPE. + ) + + self.x_embedder = PatchEmbed( + patch_size=patch_size, + in_chans=in_channels, + embed_dim=hidden_size_x, + bias=patch_embed_bias, + dtype=dtype, + device=device, + operations=operations + ) + # Conditionings + # Timestep + self.t_embedder = TimestepEmbedder( + hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations + ) + + if self.use_t5: + # Caption Pooling (T5) + self.t5_y_embedder = AttentionPool( + t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations + ) + + # Dense Embedding Projection (T5) + self.t5_yproj = operations.Linear( + t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device + ) + + # Initialize pos_frequencies as an empty parameter. + self.pos_frequencies = nn.Parameter( + torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device) + ) + + assert not self.attend_to_padding + + # for depth 48: + # b = 0: AsymmetricJointBlock, update_y=True + # b = 1: AsymmetricJointBlock, update_y=True + # ... + # b = 46: AsymmetricJointBlock, update_y=True + # b = 47: AsymmetricJointBlock, update_y=False. No need to update text features. + blocks = [] + for b in range(depth): + # Joint multi-modal block + update_y = b < depth - 1 + block = AsymmetricJointBlock( + hidden_size_x, + hidden_size_y, + num_heads, + mlp_ratio_x=mlp_ratio_x, + mlp_ratio_y=mlp_ratio_y, + update_y=update_y, + attend_to_padding=attend_to_padding, + device=device, + dtype=dtype, + operations=operations, + **block_kwargs, + ) + + blocks.append(block) + self.blocks = nn.ModuleList(blocks) + + self.final_layer = FinalLayer( + hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations + ) + + def embed_x(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, C=12, T, H, W) tensor of visual tokens + + Returns: + x: (B, C=3072, N) tensor of visual tokens with positional embedding. + """ + return self.x_embedder(x) # Convert BcTHW to BCN + + def prepare( + self, + x: torch.Tensor, + sigma: torch.Tensor, + t5_feat: torch.Tensor, + t5_mask: torch.Tensor, + ): + """Prepare input and conditioning embeddings.""" + # Visual patch embeddings with positional encoding. 
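+        # PatchEmbed patchifies spatially only, so the token count below is
+        # N = T * (H // patch_size) * (W // patch_size).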
+ T, H, W = x.shape[-3:] + pH, pW = H // self.patch_size, W // self.patch_size + x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2 + assert x.ndim == 3 + B = x.size(0) + + + pH, pW = H // self.patch_size, W // self.patch_size + N = T * pH * pW + assert x.size(1) == N + pos = create_position_matrix( + T, pH=pH, pW=pW, device=x.device, dtype=torch.float32 + ) # (N, 3) + rope_cos, rope_sin = compute_mixed_rotation( + freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos + ) # Each are (N, num_heads, dim // 2) + + c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D) + + t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D) + + c = c_t + t5_y_pool + + y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D) + + return x, c, y_feat, rope_cos, rope_sin + + def forward( + self, + x: torch.Tensor, + timestep: torch.Tensor, + context: List[torch.Tensor], + attention_mask: List[torch.Tensor], + num_tokens=256, + packed_indices: Dict[str, torch.Tensor] = None, + rope_cos: torch.Tensor = None, + rope_sin: torch.Tensor = None, + control=None, **kwargs + ): + y_feat = context + y_mask = attention_mask + sigma = timestep + """Forward pass of DiT. + + Args: + x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images) + sigma: (B,) tensor of noise standard deviations + y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) + y_mask: List((B, L) boolean tensor indicating which tokens are not padding) + packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices. + """ + B, _, T, H, W = x.shape + + x, c, y_feat, rope_cos, rope_sin = self.prepare( + x, sigma, y_feat, y_mask + ) + del y_mask + + for i, block in enumerate(self.blocks): + x, y_feat = block( + x, + c, + y_feat, + rope_cos=rope_cos, + rope_sin=rope_sin, + crop_y=num_tokens, + ) # (B, M, D), (B, L, D) + del y_feat # Final layers don't use dense text features. 
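+        # final_layer maps each token to patch_size**2 * out_channels values; the rearrange
+        # below un-patchifies the token sequence back to (B, c, T, H, W).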
+ + x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels) + x = rearrange( + x, + "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)", + T=T, + hp=H // self.patch_size, + wp=W // self.patch_size, + p1=self.patch_size, + p2=self.patch_size, + c=self.out_channels, + ) + + return -x diff --git a/comfy/ldm/genmo/joint_model/layers.py b/comfy/ldm/genmo/joint_model/layers.py new file mode 100644 index 00000000..51d97955 --- /dev/null +++ b/comfy/ldm/genmo/joint_model/layers.py @@ -0,0 +1,164 @@ +#original code from https://github.com/genmoai/models under apache 2.0 license +#adapted to ComfyUI + +import collections.abc +import math +from itertools import repeat +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import comfy.ldm.common_dit + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +class TimestepEmbedder(nn.Module): + def __init__( + self, + hidden_size: int, + frequency_embedding_size: int = 256, + *, + bias: bool = True, + timestep_scale: Optional[float] = None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + self.timestep_scale = timestep_scale + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) + freqs.mul_(-math.log(max_period) / half).exp_() + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t, out_dtype): + if self.timestep_scale is not None: + t = t * self.timestep_scale + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class FeedForward(nn.Module): + def __init__( + self, + in_features: int, + hidden_size: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + device: Optional[torch.device] = None, + dtype=None, + operations=None, + ): + super().__init__() + # keep parameter count and computation constant compared to standard FFN + hidden_size = int(2 * hidden_size / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_size = int(ffn_dim_multiplier * hidden_size) + hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of) + + self.hidden_dim = hidden_size + self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype) + self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype) + + def forward(self, x): + x, gate = self.w1(x).chunk(2, dim=-1) + x = self.w2(F.silu(x) * gate) + return x + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + bias: bool = True, + dynamic_img_pad: bool = False, + dtype=None, + device=None, + 
operations=None, + ): + super().__init__() + self.patch_size = to_2tuple(patch_size) + self.flatten = flatten + self.dynamic_img_pad = dynamic_img_pad + + self.proj = operations.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + device=device, + dtype=dtype, + ) + assert norm_layer is None + self.norm = ( + norm_layer(embed_dim, device=device) if norm_layer else nn.Identity() + ) + + def forward(self, x): + B, _C, T, H, W = x.shape + if not self.dynamic_img_pad: + assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." + else: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T) + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular') + x = self.proj(x) + + # Flatten temporal and spatial dimensions. + if not self.flatten: + raise NotImplementedError("Must flatten output.") + x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T) + + x = self.norm(x) + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype)) + self.register_parameter("bias", None) + + def forward(self, x): + return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps) diff --git a/comfy/ldm/genmo/joint_model/rope_mixed.py b/comfy/ldm/genmo/joint_model/rope_mixed.py new file mode 100644 index 00000000..dee3fa21 --- /dev/null +++ b/comfy/ldm/genmo/joint_model/rope_mixed.py @@ -0,0 +1,88 @@ +#original code from https://github.com/genmoai/models under apache 2.0 license + +# import functools +import math + +import torch + + +def centers(start: float, stop, num, dtype=None, device=None): + """linspace through bin centers. + + Args: + start (float): Start of the range. + stop (float): End of the range. + num (int): Number of points. + dtype (torch.dtype): Data type of the points. + device (torch.device): Device of the points. + + Returns: + centers (Tensor): Centers of the bins. Shape: (num,). + """ + edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device) + return (edges[:-1] + edges[1:]) / 2 + + +# @functools.lru_cache(maxsize=1) +def create_position_matrix( + T: int, + pH: int, + pW: int, + device: torch.device, + dtype: torch.dtype, + *, + target_area: float = 36864, +): + """ + Args: + T: int - Temporal dimension + pH: int - Height dimension after patchify + pW: int - Width dimension after patchify + + Returns: + pos: [T * pH * pW, 3] - position matrix + """ + # Create 1D tensors for each dimension + t = torch.arange(T, dtype=dtype) + + # Positionally interpolate to area 36864. + # (3072x3072 frame with 16x16 patches = 192x192 latents). + # This automatically scales rope positions when the resolution changes. + # We use a large target area so the model is more sensitive + # to changes in the learned pos_frequencies matrix. 
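+    # scale stretches the pH x pW grid so that (pW * scale) * (pH * scale) == target_area.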
+ scale = math.sqrt(target_area / (pW * pH)) + w = centers(-pW * scale / 2, pW * scale / 2, pW) + h = centers(-pH * scale / 2, pH * scale / 2, pH) + + # Use meshgrid to create 3D grids + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + # Stack and reshape the grids. + pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3] + pos = pos.view(-1, 3) # [T * pH * pW, 3] + pos = pos.to(dtype=dtype, device=device) + + return pos + + +def compute_mixed_rotation( + freqs: torch.Tensor, + pos: torch.Tensor, +): + """ + Project each 3-dim position into per-head, per-head-dim 1D frequencies. + + Args: + freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position + pos: [N, 3] - position of each token + num_heads: int + + Returns: + freqs_cos: [N, num_heads, num_freqs] - cosine components + freqs_sin: [N, num_heads, num_freqs] - sine components + """ + assert freqs.ndim == 3 + freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs) + freqs_cos = torch.cos(freqs_sum) + freqs_sin = torch.sin(freqs_sum) + return freqs_cos, freqs_sin diff --git a/comfy/ldm/genmo/joint_model/temporal_rope.py b/comfy/ldm/genmo/joint_model/temporal_rope.py new file mode 100644 index 00000000..88f5d6d2 --- /dev/null +++ b/comfy/ldm/genmo/joint_model/temporal_rope.py @@ -0,0 +1,34 @@ +#original code from https://github.com/genmoai/models under apache 2.0 license + +# Based on Llama3 Implementation. +import torch + + +def apply_rotary_emb_qk_real( + xqk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers. + + Args: + xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D) + Can be either just query or just key, or both stacked along some batch or * dim. + freqs_cos (torch.Tensor): Precomputed cosine frequency tensor. + freqs_sin (torch.Tensor): Precomputed sine frequency tensor. + + Returns: + torch.Tensor: The input tensor with rotary embeddings applied. + """ + # Split the last dimension into even and odd parts + xqk_even = xqk[..., 0::2] + xqk_odd = xqk[..., 1::2] + + # Apply rotation + cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk) + sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk) + + # Interleave the results back into the original shape + out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2) + return out diff --git a/comfy/ldm/genmo/joint_model/utils.py b/comfy/ldm/genmo/joint_model/utils.py new file mode 100644 index 00000000..41190242 --- /dev/null +++ b/comfy/ldm/genmo/joint_model/utils.py @@ -0,0 +1,102 @@ +#original code from https://github.com/genmoai/models under apache 2.0 license +#adapted to ComfyUI + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: + """ + Pool tokens in x using mask. + + NOTE: We assume x does not require gradients. + + Args: + x: (B, L, D) tensor of tokens. + mask: (B, L) boolean tensor indicating which tokens are not padding. + + Returns: + pooled: (B, D) tensor of pooled tokens. + """ + assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. 
+ assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. + mask = mask[:, :, None].to(dtype=x.dtype) + mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = (x * mask).sum(dim=1, keepdim=keepdim) + return pooled + + +class AttentionPool(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + output_dim: int = None, + device: Optional[torch.device] = None, + dtype=None, + operations=None, + ): + """ + Args: + spatial_dim (int): Number of tokens in sequence length. + embed_dim (int): Dimensionality of input tokens. + num_heads (int): Number of attention heads. + output_dim (int): Dimensionality of output tokens. Defaults to embed_dim. + """ + super().__init__() + self.num_heads = num_heads + self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype) + self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask): + """ + Args: + x (torch.Tensor): (B, L, D) tensor of input tokens. + mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding. + + NOTE: We assume x does not require gradients. + + Returns: + x (torch.Tensor): (B, D) tensor of pooled tokens. + """ + D = x.size(2) + + # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L). + attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L). + attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L). + + # Average non-padding token features. These will be used as the query. + x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D) + + # Concat pooled features to input sequence. + x = torch.cat([x_pool, x], dim=1) # (B, L+1, D) + + # Compute queries, keys, values. Only the mean token is used to create a query. + kv = self.to_kv(x) # (B, L+1, 2 * D) + q = self.to_q(x[:, 0]) # (B, D) + + # Extract heads. + head_dim = D // self.num_heads + kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim) + kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim) + k, v = kv.unbind(2) # (B, H, 1+L, head_dim) + q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim) + q = q.unsqueeze(2) # (B, H, 1, head_dim) + + # Compute attention. + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=0.0 + ) # (B, H, 1, head_dim) + + # Concatenate heads and run output. + x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) + x = self.to_out(x) + return x diff --git a/comfy/ldm/genmo/vae/model.py b/comfy/ldm/genmo/vae/model.py new file mode 100644 index 00000000..e44c08a4 --- /dev/null +++ b/comfy/ldm/genmo/vae/model.py @@ -0,0 +1,480 @@ +#original code from https://github.com/genmoai/models under apache 2.0 license +#adapted to ComfyUI + +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +import comfy.ops +ops = comfy.ops.disable_weight_init + +# import mochi_preview.dit.joint_model.context_parallel as cp +# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +class GroupNormSpatial(ops.GroupNorm): + """ + GroupNorm applied per-frame. + """ + + def forward(self, x: torch.Tensor, *, chunk_size: int = 8): + B, C, T, H, W = x.shape + x = rearrange(x, "B C T H W -> (B T) C H W") + # Run group norm in chunks. 
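+        # Normalizing in chunks of frames keeps peak memory bounded when B * T is large.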
+ output = torch.empty_like(x) + for b in range(0, B * T, chunk_size): + output[b : b + chunk_size] = super().forward(x[b : b + chunk_size]) + return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T) + +class PConv3d(ops.Conv3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]], + causal: bool = True, + context_parallel: bool = True, + **kwargs, + ): + self.causal = causal + self.context_parallel = context_parallel + kernel_size = cast_tuple(kernel_size, 3) + stride = cast_tuple(stride, 3) + height_pad = (kernel_size[1] - 1) // 2 + width_pad = (kernel_size[2] - 1) // 2 + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=(1, 1, 1), + padding=(0, height_pad, width_pad), + **kwargs, + ) + + def forward(self, x: torch.Tensor): + # Compute padding amounts. + context_size = self.kernel_size[0] - 1 + if self.causal: + pad_front = context_size + pad_back = 0 + else: + pad_front = context_size // 2 + pad_back = context_size - pad_front + + # Apply padding. + assert self.padding_mode == "replicate" # DEBUG + mode = "constant" if self.padding_mode == "zeros" else self.padding_mode + x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode) + return super().forward(x) + + +class Conv1x1(ops.Linear): + """*1x1 Conv implemented with a linear layer.""" + + def __init__(self, in_features: int, out_features: int, *args, **kwargs): + super().__init__(in_features, out_features, *args, **kwargs) + + def forward(self, x: torch.Tensor): + """Forward pass. + + Args: + x: Input tensor. Shape: [B, C, *] or [B, *, C]. + + Returns: + x: Output tensor. Shape: [B, C', *] or [B, *, C']. + """ + x = x.movedim(1, -1) + x = super().forward(x) + x = x.movedim(-1, 1) + return x + + +class DepthToSpaceTime(nn.Module): + def __init__( + self, + temporal_expansion: int, + spatial_expansion: int, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + # When printed, this module should show the temporal and spatial expansion factors. + def extra_repr(self): + return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}" + + def forward(self, x: torch.Tensor): + """Forward pass. + + Args: + x: Input tensor. Shape: [B, C, T, H, W]. + + Returns: + x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s]. + """ + x = rearrange( + x, + "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)", + st=self.temporal_expansion, + sh=self.spatial_expansion, + sw=self.spatial_expansion, + ) + + # cp_rank, _ = cp.get_cp_rank_size() + if self.temporal_expansion > 1: # and cp_rank == 0: + # Drop the first self.temporal_expansion - 1 frames. + # This is because we always want the 3x3x3 conv filter to only apply + # to the first frame, and the first frame doesn't need to be repeated. 
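+            # After this crop, n input frames expand to (n - 1) * temporal_expansion + 1 output frames.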
+ assert all(x.shape) + x = x[:, :, self.temporal_expansion - 1 :] + assert all(x.shape) + + return x + + +def norm_fn( + in_channels: int, + affine: bool = True, +): + return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels) + + +class ResBlock(nn.Module): + """Residual block that preserves the spatial dimensions.""" + + def __init__( + self, + channels: int, + *, + affine: bool = True, + attn_block: Optional[nn.Module] = None, + padding_mode: str = "replicate", + causal: bool = True, + ): + super().__init__() + self.channels = channels + + assert causal + self.stack = nn.Sequential( + norm_fn(channels, affine=affine), + nn.SiLU(inplace=True), + PConv3d( + in_channels=channels, + out_channels=channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding_mode=padding_mode, + bias=True, + # causal=causal, + ), + norm_fn(channels, affine=affine), + nn.SiLU(inplace=True), + PConv3d( + in_channels=channels, + out_channels=channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding_mode=padding_mode, + bias=True, + # causal=causal, + ), + ) + + self.attn_block = attn_block if attn_block else nn.Identity() + + def forward(self, x: torch.Tensor): + """Forward pass. + + Args: + x: Input tensor. Shape: [B, C, T, H, W]. + """ + residual = x + x = self.stack(x) + x = x + residual + del residual + + return self.attn_block(x) + + +class CausalUpsampleBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks: int, + *, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + **block_kwargs, + ): + super().__init__() + + blocks = [] + for _ in range(num_res_blocks): + blocks.append(block_fn(in_channels, **block_kwargs)) + self.blocks = nn.Sequential(*blocks) + + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + # Change channels in the final convolution layer. + self.proj = Conv1x1( + in_channels, + out_channels * temporal_expansion * (spatial_expansion**2), + ) + + self.d2st = DepthToSpaceTime( + temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion + ) + + def forward(self, x): + x = self.blocks(x) + x = self.proj(x) + x = self.d2st(x) + return x + + +def block_fn(channels, *, has_attention: bool = False, **block_kwargs): + assert has_attention is False #NOTE: if this is ever true add back the attention code. + + attn_block = None #AttentionBlock(channels) if has_attention else None + + return ResBlock( + channels, affine=True, attn_block=attn_block, **block_kwargs + ) + + +class DownsampleBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks, + *, + temporal_reduction=2, + spatial_reduction=2, + **block_kwargs, + ): + """ + Downsample block for the VAE encoder. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + num_res_blocks: Number of residual blocks. + temporal_reduction: Temporal reduction factor. + spatial_reduction: Spatial reduction factor. + """ + super().__init__() + layers = [] + + # Change the channel count in the strided convolution. + # This lets the ResBlock have uniform channel count, + # as in ConvNeXt. 
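+        # The strided PConv3d below performs the channel change and the temporal/spatial
+        # downsample in a single step.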
+ assert in_channels != out_channels + layers.append( + PConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction), + stride=(temporal_reduction, spatial_reduction, spatial_reduction), + padding_mode="replicate", + bias=True, + ) + ) + + for _ in range(num_res_blocks): + layers.append(block_fn(out_channels, **block_kwargs)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1): + num_freqs = (stop - start) // step + assert inputs.ndim == 5 + C = inputs.size(1) + + # Create Base 2 Fourier features. + freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device) + assert num_freqs == len(freqs) + w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs] + C = inputs.shape[1] + w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1] + + # Interleaved repeat of input channels to match w. + h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W] + # Scale channels by frequency. + h = w * h + + return torch.cat( + [ + inputs, + torch.sin(h), + torch.cos(h), + ], + dim=1, + ) + + +class FourierFeatures(nn.Module): + def __init__(self, start: int = 6, stop: int = 8, step: int = 1): + super().__init__() + self.start = start + self.stop = stop + self.step = step + + def forward(self, inputs): + """Add Fourier features to inputs. + + Args: + inputs: Input tensor. Shape: [B, C, T, H, W] + + Returns: + h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W] + """ + return add_fourier_features(inputs, self.start, self.stop, self.step) + + +class Decoder(nn.Module): + def __init__( + self, + *, + out_channels: int = 3, + latent_dim: int, + base_channels: int, + channel_multipliers: List[int], + num_res_blocks: List[int], + temporal_expansions: Optional[List[int]] = None, + spatial_expansions: Optional[List[int]] = None, + has_attention: List[bool], + output_norm: bool = True, + nonlinearity: str = "silu", + output_nonlinearity: str = "silu", + causal: bool = True, + **block_kwargs, + ): + super().__init__() + self.input_channels = latent_dim + self.base_channels = base_channels + self.channel_multipliers = channel_multipliers + self.num_res_blocks = num_res_blocks + self.output_nonlinearity = output_nonlinearity + assert nonlinearity == "silu" + assert causal + + ch = [mult * base_channels for mult in channel_multipliers] + self.num_up_blocks = len(ch) - 1 + assert len(num_res_blocks) == self.num_up_blocks + 2 + + blocks = [] + + first_block = [ + nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1)) + ] # Input layer. + # First set of blocks preserve channel count. + for _ in range(num_res_blocks[-1]): + first_block.append( + block_fn( + ch[-1], + has_attention=has_attention[-1], + causal=causal, + **block_kwargs, + ) + ) + blocks.append(nn.Sequential(*first_block)) + + assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks + assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2 + + upsample_block_fn = CausalUpsampleBlock + + for i in range(self.num_up_blocks): + block = upsample_block_fn( + ch[-i - 1], + ch[-i - 2], + num_res_blocks=num_res_blocks[-i - 2], + has_attention=has_attention[-i - 2], + temporal_expansion=temporal_expansions[-i - 1], + spatial_expansion=spatial_expansions[-i - 1], + causal=causal, + **block_kwargs, + ) + blocks.append(block) + + assert not output_norm + + # Last block. 
Preserve channel count. + last_block = [] + for _ in range(num_res_blocks[0]): + last_block.append( + block_fn( + ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs + ) + ) + blocks.append(nn.Sequential(*last_block)) + + self.blocks = nn.ModuleList(blocks) + self.output_proj = Conv1x1(ch[0], out_channels) + + def forward(self, x): + """Forward pass. + + Args: + x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1]. + + Returns: + x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]. + T + 1 = (t - 1) * 4. + H = h * 16, W = w * 16. + """ + for block in self.blocks: + x = block(x) + + if self.output_nonlinearity == "silu": + x = F.silu(x, inplace=not self.training) + else: + assert ( + not self.output_nonlinearity + ) # StyleGAN3 omits the to-RGB nonlinearity. + + return self.output_proj(x).contiguous() + + +class VideoVAE(nn.Module): + def __init__(self): + super().__init__() + self.encoder = None #TODO once the model releases + self.decoder = Decoder( + out_channels=3, + base_channels=128, + channel_multipliers=[1, 2, 4, 6], + temporal_expansions=[1, 2, 3], + spatial_expansions=[2, 2, 2], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + has_attention=[False, False, False, False, False], + padding_mode="replicate", + output_norm=False, + nonlinearity="silu", + output_nonlinearity="silu", + causal=True, + ) + + def encode(self, x): + return self.encoder(x) + + def decode(self, x): + return self.decoder(x) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index b085bbc0..a160b2f4 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -97,7 +97,7 @@ class PatchEmbed(nn.Module): self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): - B, C, H, W = x.shape + # B, C, H, W = x.shape # if self.img_size is not None: # if self.strict_img_size: # _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") diff --git a/comfy/model_base.py b/comfy/model_base.py index 5138d2b9..f2833168 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -24,6 +24,7 @@ from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper +import comfy.ldm.genmo.joint_model.asymm_models_joint import comfy.ldm.aura.mmdit import comfy.ldm.hydit.models import comfy.ldm.audio.dit @@ -718,3 +719,18 @@ class Flux(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)])) return out + +class GenmoMochi(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = 
comfy.conds.CONDRegular(cross_attn) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index e1d29db3..5d2abe1b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -145,6 +145,34 @@ def detect_unet_config(state_dict, key_prefix): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config + if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview + dit_config = {} + dit_config["image_model"] = "mochi_preview" + dit_config["depth"] = 48 + dit_config["patch_size"] = 2 + dit_config["num_heads"] = 24 + dit_config["hidden_size_x"] = 3072 + dit_config["hidden_size_y"] = 1536 + dit_config["mlp_ratio_x"] = 4.0 + dit_config["mlp_ratio_y"] = 4.0 + dit_config["learn_sigma"] = False + dit_config["in_channels"] = 12 + dit_config["qk_norm"] = True + dit_config["qkv_bias"] = False + dit_config["out_bias"] = True + dit_config["attn_drop"] = 0.0 + dit_config["patch_embed_bias"] = True + dit_config["posenc_preserve_area"] = True + dit_config["timestep_mlp_bias"] = True + dit_config["attend_to_padding"] = False + dit_config["timestep_scale"] = 1000.0 + dit_config["use_t5"] = True + dit_config["t5_feat_dim"] = 4096 + dit_config["t5_token_length"] = 256 + dit_config["rope_theta"] = 10000.0 + return dit_config + + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/samplers.py b/comfy/samplers.py index f85bd203..94cba03b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -366,6 +366,27 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6): sigs += [0.0] return torch.FloatTensor(sigs) +# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41 +def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None): + if steps == 1: + sigma_schedule = [1.0, 0.0] + else: + if linear_steps is None: + linear_steps = steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * steps + quadratic_steps = steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2) + const = quadratic_coef * (linear_steps ** 2) + quadratic_sigma_schedule = [ + quadratic_coef * (i ** 2) + linear_coef * i + const + for i in range(linear_steps, steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu() + def get_mask_aabb(masks): if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device, dtype=torch.int) @@ -732,7 +753,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) -SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"] +SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] def calculate_sigmas(model_sampling, scheduler_name, steps): @@ -750,6 +771,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps): sigmas = normal_scheduler(model_sampling, 
steps, sgm=True) elif scheduler_name == "beta": sigmas = beta_scheduler(model_sampling, steps) + elif scheduler_name == "linear_quadratic": + sigmas = linear_quadratic_schedule(model_sampling, steps) else: logging.error("error invalid scheduler {}".format(scheduler_name)) return sigmas diff --git a/comfy/sd.py b/comfy/sd.py index e4abf0b9..a65382b8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -7,6 +7,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.audio.autoencoder import AudioOobleckVAE +import comfy.ldm.genmo.vae.model import yaml import comfy.utils @@ -25,6 +26,7 @@ import comfy.text_encoders.aura_t5 import comfy.text_encoders.hydit import comfy.text_encoders.flux import comfy.text_encoders.long_clipl +import comfy.text_encoders.genmo import comfy.model_patcher import comfy.lora @@ -241,6 +243,13 @@ class VAE: self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae + if "blocks.2.blocks.3.stack.5.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."}) + self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE() + self.latent_channels = 12 + self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) + self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -296,6 +305,10 @@ class VAE: decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device) + def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) + def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -314,6 +327,7 @@ class VAE: return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device) def decode(self, samples_in): + pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used) @@ -321,16 +335,21 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, 
samples_in.shape[2:])), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) + out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) + if pixel_samples is None: + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) + pixel_samples[x:x+batch_number] = out except model_management.OOM_EXCEPTION as e: logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") - if len(samples_in.shape) == 3: + dims = samples_in.ndim - 2 + if dims == 1: pixel_samples = self.decode_tiled_1d(samples_in) - else: + elif dims == 2: pixel_samples = self.decode_tiled_(samples_in) + elif dims == 3: + pixel_samples = self.decode_tiled_3d(samples_in) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples @@ -398,6 +417,7 @@ class CLIPType(Enum): STABLE_AUDIO = 4 HUNYUAN_DIT = 5 FLUX = 6 + MOCHI = 7 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] @@ -474,8 +494,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer elif te_model == TEModel.T5_XXL: - clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + if clip_type == CLIPType.SD3: + clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + else: #CLIPType.MOCHI + clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer elif te_model == TEModel.T5_XL: clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index da39ccf4..57099082 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -10,6 +10,7 @@ import comfy.text_encoders.sa_t5 import comfy.text_encoders.aura_t5 import comfy.text_encoders.hydit import comfy.text_encoders.flux +import comfy.text_encoders.genmo from . import supported_models_base from . 
import latent_formats @@ -670,7 +671,36 @@ class FluxSchnell(Flux): out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) return out +class GenmoMochi(supported_models_base.BASE): + unet_config = { + "image_model": "mochi_preview", + } -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell] + sampling_settings = { + "multiplier": 1.0, + "shift": 1.0, + } + + unet_extra_config = {} + latent_format = latent_formats.Mochi + + memory_usage_factor = 2.0 #TODO + + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.GenmoMochi(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect)) + + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi] models += [SVD_img2vid] diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py new file mode 100644 index 00000000..5e96cea6 --- /dev/null +++ b/comfy/text_encoders/genmo.py @@ -0,0 +1,38 @@ +from comfy import sd1_clip +import comfy.text_encoders.sd3_clip +import os +from transformers import T5TokenizerFast + + +class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel): + def __init__(self, **kwargs): + kwargs["attention_mask"] = True + super().__init__(**kwargs) + + +class MochiT5XXL(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options) + + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + + +class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) + + +def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None): + class MochiTEModel_(MochiT5XXL): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if dtype is None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, 
model_options=model_options) + return MochiTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 056cf363..06c81107 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -731,7 +731,27 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): @torch.inference_mode() def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): dims = len(tile) - output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device) + + if not (isinstance(upscale_amount, (tuple, list))): + upscale_amount = [upscale_amount] * dims + + if not (isinstance(overlap, (tuple, list))): + overlap = [overlap] * dims + + def get_upscale(dim, val): + up = upscale_amount[dim] + if callable(up): + return up(val) + else: + return up * val + + def mult_list_upscale(a): + out = [] + for i in range(len(a)): + out.append(round(get_upscale(i, a[i]))) + return out + + output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) for b in range(samples.shape[0]): s = samples[b:b+1] @@ -743,27 +763,27 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_ pbar.update(1) continue - out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) - out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) + out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) + out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) - positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] + positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] for it in itertools.product(*positions): s_in = s upscaled = [] for d in range(dims): - pos = max(0, min(s.shape[d + 2] - overlap, it[d])) + pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d])) l = min(tile[d], s.shape[d + 2] - pos) s_in = s_in.narrow(d + 2, pos, l) - upscaled.append(round(pos * upscale_amount)) + upscaled.append(round(get_upscale(d, pos))) ps = function(s_in).to(output_device) mask = torch.ones_like(ps) - feather = round(overlap * upscale_amount) - for t in range(feather): - for d in range(2, dims + 2): + for d in range(2, dims + 2): + feather = round(get_upscale(d - 2, overlap[d - 2])) + for t in range(feather): a = (t + 1) / feather mask.narrow(d, t, 1).mul_(a) mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) diff --git a/comfy_extras/nodes_mochi.py b/comfy_extras/nodes_mochi.py new file mode 100644 index 00000000..4cbbea09 --- /dev/null +++ b/comfy_extras/nodes_mochi.py @@ -0,0 +1,26 @@ +import nodes +import torch +import comfy.model_management + +class EmptyMochiLatentVideo: + def __init__(self): + self.device = comfy.model_management.intermediate_device() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) + 
FUNCTION = "generate" + + CATEGORY = "latent/mochi" + + def generate(self, width, height, length, batch_size=1): + latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device) + return ({"samples":latent}, ) + +NODE_CLASS_MAPPINGS = { + "EmptyMochiLatentVideo": EmptyMochiLatentVideo, +} diff --git a/nodes.py b/nodes.py index ff45acf8..c81a0af1 100644 --- a/nodes.py +++ b/nodes.py @@ -281,7 +281,10 @@ class VAEDecode: DESCRIPTION = "Decodes latent images back into pixel space images." def decode(self, vae, samples): - return (vae.decode(samples["samples"]), ) + images = vae.decode(samples["samples"]) + if len(images.shape) == 5: #Combine batches + images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) + return (images, ) class VAEDecodeTiled: @classmethod @@ -886,7 +889,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -900,6 +903,8 @@ class CLIPLoader: clip_type = comfy.sd.CLIPType.SD3 elif type == "stable_audio": clip_type = comfy.sd.CLIPType.STABLE_AUDIO + elif type == "mochi": + clip_type = comfy.sd.CLIPType.MOCHI else: clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION @@ -2111,6 +2116,7 @@ def init_builtin_extra_nodes(): "nodes_flux.py", "nodes_lora_extract.py", "nodes_torch_compile.py", + "nodes_mochi.py", ] import_failed = []
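As a quick illustration of the latent-side arithmetic introduced by this patch (the latent shape produced by EmptyMochiLatentVideo, the upscale_ratio registered for the Mochi VAE, and the per-channel normalization in latent_formats.Mochi), here is a small standalone sketch. It is not part of the patch; the 848x480x25 defaults are taken from the node definition above, and the round-trip check only relies on process_in/process_out being exact inverses up to float error.

    # Illustrative sanity check mirroring the arithmetic in this patch (not part of the patch).
    import torch
    from comfy.latent_formats import Mochi

    width, height, length = 848, 480, 25
    t = ((length - 1) // 6) + 1                               # EmptyMochiLatentVideo: 5 latent frames
    latent = torch.zeros(1, 12, t, height // 8, width // 8)   # (1, 12, 5, 60, 106)

    # VAE upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8):
    # 5 latent frames -> 25 video frames, 60x106 latents -> 480x848 pixels.
    frames = max(0, t * 6 - 5)
    assert (frames, latent.shape[-2] * 8, latent.shape[-1] * 8) == (length, height, width)

    # latent_formats.Mochi is a pure per-channel (x - mean) / std mapping (scale_factor = 1.0),
    # so process_out undoes process_in.
    fmt = Mochi()
    x = torch.randn_like(latent)
    assert torch.allclose(fmt.process_out(fmt.process_in(x)), x, atol=1e-5)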