diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6a9a9620..92f39d5c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -135,3 +135,6 @@ class SD3(LatentFormat): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor + +class StableAudio1(LatentFormat): + latent_channels = 64 diff --git a/comfy/ldm/audio/autoencoder.py b/comfy/ldm/audio/autoencoder.py new file mode 100644 index 00000000..7363131e --- /dev/null +++ b/comfy/ldm/audio/autoencoder.py @@ -0,0 +1,276 @@ +# code adapted from: https://github.com/Stability-AI/stable-audio-tools + +import torch +from torch import nn +from typing import Literal, Dict, Any +import math +import comfy.ops +ops = comfy.ops.disable_weight_init + +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + +class VAEBottleneck(nn.Module): + def __init__(self): + super().__init__() + self.is_discrete = False + + def encode(self, x, return_info=False, **kwargs): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["kl"] = kl + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + # self.alpha.requires_grad = alpha_trainable + # self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x + +def WNConv1d(*args, **kwargs): + return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) + +def WNConvTranspose1d(*args, **kwargs): + return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = torch.nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = torch.nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", 
antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + 
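Aside (not part of the patch): the temporal compression of this Oobleck autoencoder is simply the product of its strides, which is where the 2048x downscale_ratio/upscale_ratio that comfy/sd.py assigns to this VAE later in the patch comes from. A quick check with the AudioOobleckVAE defaults defined further down in this file (variable names here are illustrative only):

```python
import math

strides = [2, 4, 4, 8, 8]        # AudioOobleckVAE defaults further down in this file
compression = math.prod(strides)
print(compression)               # 2048 audio samples per latent frame

samples = 2 * 44100              # two seconds at 44.1 kHz
print(samples // compression)    # -> 43 latent frames
```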
super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class AudioOobleckVAE(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=64, + c_mults = [1, 2, 4, 8, 16], + strides = [2, 4, 4, 8, 8], + use_snake=True, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=False): + super().__init__() + self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation) + self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation, + use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh) + self.bottleneck = VAEBottleneck() + + def encode(self, x): + return self.bottleneck.encode(self.encoder(x)) + + def decode(self, x): + return self.decoder(self.bottleneck.decode(x)) + diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py new file mode 100644 index 00000000..1c1112c5 --- /dev/null +++ b/comfy/ldm/audio/dit.py @@ -0,0 +1,888 @@ +# code adapted from: https://github.com/Stability-AI/stable-audio-tools + +from comfy.ldm.modules.attention import optimized_attention +import typing as tp + +import torch + +from einops import rearrange +from torch import nn +from torch.nn import functional as F +import math + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1., dtype=None, device=None): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.empty( + [out_features // 2, in_features], dtype=dtype, device=device)) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device) + return torch.cat([f.cos(), f.sin()], dim=-1) + +# norms +class LayerNorm(nn.Module): + def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None): + """ + bias-less layernorm has been shown to be more stable. 
most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + + if bias: + self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + else: + self.beta = None + + def forward(self, x): + beta = self.beta + if self.beta is not None: + beta = beta.to(dtype=x.dtype, device=x.device) + return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta) + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation, + use_conv = False, + conv_kernel_size = 3, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.act = activation + self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = torch.einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. 
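Aside (not part of the patch): RotaryEmbedding only precomputes per-pair inverse frequencies; the rotation itself happens in apply_rotary_pos_emb just below. A minimal self-contained sketch of the same computation, ignoring the xpos scale and the partial-rotary split (function names here are illustrative, not from the patch):

```python
import torch

def rope_freqs(seq_len, dim, base=10000.0, interpolation_factor=1.0):
    # Same per-pair inverse frequencies as RotaryEmbedding.__init__ above.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float() / interpolation_factor   # position interpolation
    freqs = torch.einsum('i,j->ij', t, inv_freq)                # (seq_len, dim // 2)
    return torch.cat((freqs, freqs), dim=-1)                    # (seq_len, dim)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(t, freqs):
    # Rotate each (x1, x2) pair by the angle position * inv_freq.
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

q = torch.randn(1, 8, 128, 64)                            # (batch, heads, seq, dim_head)
q_rot = apply_rope(q, rope_freqs(seq_len=128, dim=64))    # same shape, positions encoded
```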
+ self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len, device, dtype): + # device = self.inv_freq.device + + t = torch.arange(seq_len, device=device, dtype=dtype) + return self.forward(t) + + def forward(self, t): + # device = self.inv_freq.device + device = t.device + dtype = t.dtype + + # t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device)) + freqs = torch.cat((freqs, freqs), dim = -1) + + if self.scale is None: + return freqs, 1. + + power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base + scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device) + + # # init last linear layer to 0 + # if zero_init_output: + # nn.init.zeros_(linear_out.weight) + # if not no_bias: + # nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + 
dim_context = None, + causal = False, + zero_init_output=True, + qk_norm = False, + natten_kernel_size = None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.causal = causal + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device) + else: + self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device) + + self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + + # if zero_init_output: + # nn.init.zeros_(self.to_out.weight) + + self.qk_norm = qk_norm + + + def forward( + self, + x, + context = None, + mask = None, + context_mask = None, + rotary_pos_emb = None, + causal = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm: + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + if rotary_pos_emb is not None and not has_context: + freqs, _ = rotary_pos_emb + + q_dtype = q.dtype + k_dtype = k.dtype + + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + + q = apply_rotary_pos_emb(q, freqs) + k = apply_rotary_pos_emb(k, freqs) + + q = q.to(q_dtype) + k = k.to(k_dtype) + + input_mask = context_mask + + if input_mask is None and not has_context: + input_mask = mask + + # determine masking + masks = [] + final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account + + if input_mask is not None: + input_mask = rearrange(input_mask, 'b j -> b 1 1 j') + masks.append(~input_mask) + + # Other masks will be added here later + + if len(masks) > 0: + final_attn_mask = ~or_reduce(masks) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if h != kv_h: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = h // kv_h + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + out = optimized_attention(q, k, v, h, skip_reshape=True) + out = self.to_out(out) + + if mask is not None: + mask = rearrange(mask, 'b n -> b n 1') + out = out.masked_fill(~mask, 0.) 
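Aside (not part of the patch): skip_reshape=True tells optimized_attention that q, k and v are already split into heads as (batch, heads, seq, dim_head), while the result still comes back fused as (batch, seq, heads * dim_head) for to_out; the comfy/ldm/modules/attention.py hunks later in this patch add that flag to each backend. A shape-only sketch of the two conventions using plain scaled_dot_product_attention (not the patched comfy functions):

```python
import torch
import torch.nn.functional as F

b, heads, n, dim_head = 2, 8, 77, 64

# skip_reshape=False convention: heads fused in the last dim, (b, n, heads * dim_head)
q_fused = torch.randn(b, n, heads * dim_head)
q_split = q_fused.view(b, n, heads, dim_head).transpose(1, 2)   # -> (b, heads, n, dim_head)

# skip_reshape=True convention: heads already split out, as in this forward()
k_split = torch.randn(b, heads, n, dim_head)
v_split = torch.randn(b, heads, n, dim_head)

out = F.scaled_dot_product_attention(q_split, k_split, v_split)  # (b, heads, n, dim_head)
out = out.transpose(1, 2).reshape(b, n, heads * dim_head)        # fused layout for to_out
```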
+ + return out + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + attn_kwargs = {}, + ff_kwargs = {}, + norm_kwargs = {}, + dtype=None, + device=None, + operations=None, + ): + + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + + self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + + self.self_attn = Attention( + dim, + dim_heads = dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + dtype=dtype, + device=device, + operations=operations, + **attn_kwargs + ) + + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attn = Attention( + dim, + dim_heads = dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + dtype=dtype, + device=device, + operations=operations, + **attn_kwargs + ) + + self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs) + + self.layer_ix = layer_ix + + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None + + self.global_cond_dim = global_cond_dim + + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Sequential( + nn.SiLU(), + nn.Linear(global_cond_dim, dim * 6, bias=False) + ) + + nn.init.zeros_(self.to_scale_shift_gate[1].weight) + #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + + def forward( + self, + x, + context = None, + global_cond=None, + mask = None, + context_mask = None, + rotary_pos_emb = None + ): + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x * torch.sigmoid(1 - gate_self) + 
x = x + residual + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = x + residual + + else: + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + x = x + self.ff(self.ff_norm(x)) + + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + dtype=None, + device=None, + operations=None, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + + self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity() + self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity() + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + + for i in range(depth): + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + dtype=dtype, + device=device, + operations=operations, + **kwargs + ) + ) + + def forward( + self, + x, + mask = None, + prepend_embeds = None, + prepend_mask = None, + global_cond = None, + return_info = False, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if prepend_mask is not None or mask is not None: + mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool) + prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool) + + mask = torch.cat((prepend_mask, mask), dim = -1) + + # Attention layers + + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + # Iterate over the transformer layers + for layer in self.layers: + x = layer(x, 
rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + x = self.project_out(x) + + if return_info: + return x, info + + return x + +class AudioDiffusionTransformer(nn.Module): + def __init__(self, + io_channels=64, + patch_size=1, + embed_dim=1536, + cond_token_dim=768, + project_cond_tokens=False, + global_cond_dim=1536, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + depth=24, + num_heads=24, + transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + audio_model="", + dtype=None, + device=None, + operations=None, + **kwargs): + + super().__init__() + + self.dtype = dtype + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + timestep_features_dim = 256 + + self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device) + + self.to_timestep_embed = nn.Sequential( + operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device), + ) + + if cond_token_dim > 0: + # Conditioning tokens + + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device) + ) + else: + cond_embed_dim = 0 + + if global_cond_dim > 0: + # Global conditioning + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.global_cond_type = global_cond_type + + if self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend = cond_token_dim > 0, + cond_token_dim = cond_embed_dim, + global_cond_dim=global_dim, + dtype=dtype, + device=device, + operations=operations, + **kwargs + ) + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device) + self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + 
cross_attn_cond_mask=None, + input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + return_info=False, + **kwargs): + + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim) + + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + if global_embed is not None: + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend": + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + prepend_inputs = global_embed.unsqueeze(1) + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + + x = rearrange(x, "b c t -> b t c") + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if self.transformer_type == "x-transformers": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) + elif self.transformer_type == "continuous_transformer": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + + if return_info: + output, info = output + elif self.transformer_type == "mm_transformer": + output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs) + + output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + timestep, + context=None, + context_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + mask=None, + 
return_info=False, + control=None, + transformer_options={}, + **kwargs): + return self._forward( + x, + timestep, + cross_attn_cond=context, + cross_attn_cond_mask=context_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + return_info=return_info, + **kwargs + ) diff --git a/comfy/ldm/audio/embedders.py b/comfy/ldm/audio/embedders.py new file mode 100644 index 00000000..82a3210c --- /dev/null +++ b/comfy/ldm/audio/embedders.py @@ -0,0 +1,108 @@ +# code adapted from: https://github.com/Stability-AI/stable-audio-tools + +import torch +import torch.nn as nn +from torch import Tensor, einsum +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from einops import rearrange +import math +import comfy.ops + +class LearnedPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.empty(half_dim)) + + def forward(self, x: Tensor) -> Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + +def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features), + ) + + +class NumberEmbedder(nn.Module): + def __init__( + self, + features: int, + dim: int = 256, + ): + super().__init__() + self.features = features + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + + def forward(self, x: Union[List[float], Tensor]) -> Tensor: + if not torch.is_tensor(x): + device = next(self.embedding.parameters()).device + x = torch.tensor(x, device=device) + assert isinstance(x, Tensor) + shape = x.shape + x = rearrange(x, "... 
-> (...)") + embedding = self.embedding(x) + x = embedding.view(*shape, self.features) + return x # type: ignore + + +class Conditioner(nn.Module): + def __init__( + self, + dim: int, + output_dim: int, + project_out: bool = False + ): + + super().__init__() + + self.dim = dim + self.output_dim = output_dim + self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity() + + def forward(self, x): + raise NotImplementedError() + +class NumberConditioner(Conditioner): + ''' + Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings + ''' + def __init__(self, + output_dim: int, + min_val: float=0, + max_val: float=1 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + + self.embedder = NumberEmbedder(features=output_dim) + + def forward(self, floats, device=None): + # Cast the inputs to floats + floats = [float(x) for x in floats] + + if device is None: + device = next(self.embedder.parameters()).device + + floats = torch.tensor(floats).to(device) + + floats = floats.clamp(self.min_val, self.max_val) + + normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) + + # Cast floats to same type as embedder + embedder_dtype = next(self.embedder.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + float_embeds = self.embedder(normalized_floats).unsqueeze(1) + + return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index da9f7aab..65a8bcf4 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -86,22 +86,32 @@ class FeedForward(nn.Module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -def attention_basic(q, k, v, heads, mask=None, attn_precision=None): +def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): attn_precision = get_attn_precision(attn_precision) - b, _, dim_head = q.shape - dim_head //= heads + if skip_reshape: + b, _, _, dim_head = q.shape + else: + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 h = heads - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, -1, heads, dim_head) - .permute(0, 2, 1, 3) - .reshape(b * heads, -1, dim_head) - .contiguous(), - (q, k, v), - ) + if skip_reshape: + q, k, v = map( + lambda t: t.reshape(b * heads, -1, dim_head), + (q, k, v), + ) + else: + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) # force cast to fp32 to avoid overflowing if attn_precision == torch.float32: @@ -138,17 +148,26 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None): return out -def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None): +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False): attn_precision = get_attn_precision(attn_precision) - b, _, dim_head = query.shape - dim_head //= heads + if skip_reshape: + b, _, _, dim_head = query.shape + else: + b, _, dim_head = query.shape + dim_head //= heads scale = dim_head ** -0.5 - query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) - value = value.unsqueeze(3).reshape(b, 
-1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) - key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) + if skip_reshape: + query = query.reshape(b * heads, -1, dim_head) + value = value.reshape(b * heads, -1, dim_head) + key = key.reshape(b * heads, -1, dim_head).movedim(1, 2) + else: + query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) + value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) + key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) + dtype = query.dtype upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32 @@ -200,22 +219,32 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None) hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states -def attention_split(q, k, v, heads, mask=None, attn_precision=None): +def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): attn_precision = get_attn_precision(attn_precision) - b, _, dim_head = q.shape - dim_head //= heads + if skip_reshape: + b, _, _, dim_head = q.shape + else: + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 h = heads - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, -1, heads, dim_head) - .permute(0, 2, 1, 3) - .reshape(b * heads, -1, dim_head) - .contiguous(), - (q, k, v), - ) + if skip_reshape: + q, k, v = map( + lambda t: t.reshape(b * heads, -1, dim_head), + (q, k, v), + ) + else: + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) @@ -311,9 +340,12 @@ try: except: pass -def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): - b, _, dim_head = q.shape - dim_head //= heads +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): + if skip_reshape: + b, _, _, dim_head = q.shape + else: + b, _, dim_head = q.shape + dim_head //= heads disabled_xformers = False @@ -328,10 +360,16 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): if disabled_xformers: return attention_pytorch(q, k, v, heads, mask) - q, k, v = map( - lambda t: t.reshape(b, -1, heads, dim_head), - (q, k, v), - ) + if skip_reshape: + q, k, v = map( + lambda t: t.reshape(b * heads, -1, dim_head), + (q, k, v), + ) + else: + q, k, v = map( + lambda t: t.reshape(b, -1, heads, dim_head), + (q, k, v), + ) if mask is not None: pad = 8 - q.shape[1] % 8 @@ -341,18 +379,30 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) - out = ( - out.reshape(b, -1, heads * dim_head) - ) + if skip_reshape: + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + else: + out = ( + out.reshape(b, -1, heads * dim_head) + ) + return out -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None): - b, _, dim_head = q.shape - dim_head //= heads - q, k, v = map( - lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), - (q, k, v), - ) +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, 
skip_reshape=False): + if skip_reshape: + b, _, _, dim_head = q.shape + else: + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map( + lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), + (q, k, v), + ) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = ( diff --git a/comfy/model_base.py b/comfy/model_base.py index daff6e0f..f45b375d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -6,12 +6,15 @@ from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper +import comfy.ldm.audio.dit +import comfy.ldm.audio.embedders import comfy.model_management import comfy.conds import comfy.ops from enum import Enum from . import utils import comfy.latent_formats +import math class ModelType(Enum): EPS = 1 @@ -20,9 +23,10 @@ class ModelType(Enum): STABLE_CASCADE = 4 EDM = 5 FLOW = 6 + V_PREDICTION_CONTINUOUS = 7 -from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling +from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV def model_sampling(model_config, model_type): @@ -44,6 +48,9 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.EDM: c = EDM s = ModelSamplingContinuousEDM + elif model_type == ModelType.V_PREDICTION_CONTINUOUS: + c = V_PREDICTION + s = ModelSamplingContinuousV class ModelSampling(s, c): pass @@ -236,11 +243,11 @@ class BaseModel(torch.nn.Module): if self.manual_cast_dtype is not None: dtype = self.manual_cast_dtype #TODO: this needs to be tweaked - area = input_shape[0] * input_shape[2] * input_shape[3] + area = input_shape[0] * math.prod(input_shape[2:]) return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. 
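Aside (not part of the patch): replacing input_shape[2] * input_shape[3] with math.prod(input_shape[2:]) generalizes the memory estimate from 4-D image latents (b, c, h, w) to any rank, so the new 3-D audio latents (b, 64, t) reuse the same formula; the estimate just below gets the same treatment. Illustrative numbers only:

```python
import math

image_latent = (1, 4, 128, 128)   # b, c, h, w
audio_latent = (1, 64, 1024)      # b, c, t  (Stable Audio latent)

def area(shape):
    # batch times every spatial/temporal dim, matching the patched formula
    return shape[0] * math.prod(shape[2:])

print(area(image_latent))  # 16384, same as 1 * 128 * 128
print(area(audio_latent))  # 1024,  same as 1 * 1024
```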
- area = input_shape[0] * input_shape[2] * input_shape[3] + area = input_shape[0] * math.prod(input_shape[2:]) return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) @@ -590,3 +597,33 @@ class SD3(BaseModel): else: area = input_shape[0] * input_shape[2] * input_shape[3] return (area * 0.3) * (1024 * 1024) + + +class StableAudio1(BaseModel): + def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer) + self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512) + self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512) + self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights) + self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights) + + def extra_conds(self, **kwargs): + out = {} + + noise = kwargs.get("noise", None) + device = kwargs["device"] + + seconds_start = kwargs.get("seconds_start", 0) + seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 21.53)) + + seconds_start_embed = self.seconds_start_embedder([seconds_start])[0].to(device) + seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device) + + global_embed = torch.cat([seconds_start_embed, seconds_total_embed], dim=-1).reshape((1, -1)) + out['global_embed'] = comfy.conds.CONDRegular(global_embed) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dfe0ea99..4843e6a4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -96,6 +96,11 @@ def detect_unet_config(state_dict, key_prefix): unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] return unet_config + if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit + unet_config = {} + unet_config["audio_model"] = "dit1.0" + return unet_config + unet_config = { "use_checkpoint": False, "image_size": 32, @@ -236,6 +241,13 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal else: return model_config +def unet_prefix_from_state_dict(state_dict): + if "model.model.postprocess_conv.weight" in state_dict: #audio models + unet_key_prefix = "model.model." + else: + unet_key_prefix = "model.diffusion_model." 
+ return unet_key_prefix + def convert_config(unet_config): new_config = unet_config.copy() num_res_blocks = new_config.get("num_res_blocks", None) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index d6120a83..6bd3a5d7 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -169,6 +169,14 @@ class ModelSamplingContinuousEDM(torch.nn.Module): return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) +class ModelSamplingContinuousV(ModelSamplingContinuousEDM): + def timestep(self, sigma): + return sigma.atan() / math.pi * 2 + + def sigma(self, timestep): + return (timestep * math.pi / 2).tan() + + def time_snr_shift(alpha, t): if alpha == 1.0: return t diff --git a/comfy/ops.py b/comfy/ops.py index 7ebb3dd2..0f1ceb57 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -51,6 +51,20 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class Conv1d(torch.nn.Conv1d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): def reset_parameters(self): return None @@ -133,6 +147,27 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input, output_size=None): + num_spatial_dims = 1 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, + num_spatial_dims, self.dilation) + + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.conv_transpose1d( + input, weight, bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + @classmethod def conv_nd(s, dims, *args, **kwargs): if dims == 2: @@ -147,6 +182,9 @@ class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): comfy_cast_weights = True + class Conv1d(disable_weight_init.Conv1d): + comfy_cast_weights = True + class Conv2d(disable_weight_init.Conv2d): comfy_cast_weights = True @@ -161,3 +199,6 @@ class manual_cast(disable_weight_init): class ConvTranspose2d(disable_weight_init.ConvTranspose2d): comfy_cast_weights = True + + class ConvTranspose1d(disable_weight_init.ConvTranspose1d): + comfy_cast_weights = True diff --git a/comfy/sa_t5.py b/comfy/sa_t5.py new file mode 100644 index 00000000..37be5287 --- /dev/null +++ b/comfy/sa_t5.py @@ -0,0 +1,22 @@ +from comfy import sd1_clip +from transformers import T5TokenizerFast +import comfy.t5 +import os + +class T5BaseModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json") + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True) + +class 
T5BaseTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) + +class SAT5Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer) + +class SAT5Model(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs) diff --git a/comfy/sd.py b/comfy/sd.py index 3fd9e0e9..f1e48713 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -6,7 +6,7 @@ from comfy import model_management from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_c_coder import StageC_coder - +from .ldm.audio.autoencoder import AudioOobleckVAE import yaml import comfy.utils @@ -20,6 +20,7 @@ from . import sd1_clip from . import sd2_clip from . import sdxl_clip from . import sd3_clip +from . import sa_t5 import comfy.model_patcher import comfy.lora @@ -174,6 +175,7 @@ class VAE: self.downscale_ratio = 8 self.upscale_ratio = 8 self.latent_channels = 4 + self.output_channels = 3 self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) @@ -232,6 +234,16 @@ class VAE: self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) + elif "decoder.layers.0.weight_v" in sd: + self.first_stage_model = AudioOobleckVAE() + self.memory_used_encode = lambda shape, dtype: (1767 * shape[2]) * model_management.dtype_size(dtype) #TODO: tweak for the audio VAE + self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * 64) * model_management.dtype_size(dtype) + self.latent_channels = 64 + self.output_channels = 2 + self.upscale_ratio = 2048 + self.downscale_ratio = 2048 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -260,12 +272,12 @@ class VAE: self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) def vae_encode_crop_pixels(self, pixels): - x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio - y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio - if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % self.downscale_ratio) // 2 - y_offset = (pixels.shape[2] % self.downscale_ratio) // 2 - pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + dims = pixels.shape[1:-1] + for d in range(len(dims)): + x = (dims[d] // self.downscale_ratio) * self.downscale_ratio + x_offset = (dims[d] % self.downscale_ratio) // 2 + if x != dims[d]: + pixels = pixels.narrow(d + 1, x_offset, x) return pixels def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap 
= 16): @@ -303,7 +315,7 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device) + pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) @@ -328,7 +340,7 @@ class VAE: free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) + samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() @@ -371,6 +383,7 @@ class CLIPType(Enum): STABLE_DIFFUSION = 1 STABLE_CASCADE = 2 SD3 = 3 + STABLE_AUDIO = 4 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): clip_data = [] @@ -404,6 +417,9 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) clip_target.tokenizer = sd3_clip.SD3Tokenizer + elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]: + clip_target.clip = sa_t5.SAT5Model + clip_target.tokenizer = sa_t5.SAT5Tokenizer else: clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer @@ -470,10 +486,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model_patcher = None clip_target = None - parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") + diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) + parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.") + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix) unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) @@ -488,8 +505,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) offload_device = model_management.unet_offload_device() - model = model_config.get_model(sd, "model.diffusion_model.", 
-        model.load_model_weights(sd, "model.diffusion_model.")
+        model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
+        model.load_model_weights(sd, diffusion_model_prefix)
 
     if output_vae:
         vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index c8ddf3e2..761498db 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -6,6 +6,7 @@
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
 from . import sd3_clip
+from . import sa_t5
 from . import supported_models_base
 from . import latent_formats
@@ -524,7 +525,35 @@ class SD3(supported_models_base.BASE):
         return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
 
 
+class StableAudio(supported_models_base.BASE):
+    unet_config = {
+        "audio_model": "dit1.0",
+    }
 
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
+    sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}
+
+    unet_extra_config = {}
+    latent_format = latent_formats.StableAudio1
+
+    text_encoder_key_prefix = ["text_encoders."]
+    vae_key_prefix = ["pretransform.model."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        seconds_start_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
+        seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
+        return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
+
+
+    def process_unet_state_dict(self, state_dict):
+        for k in list(state_dict.keys()):
+            if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
+                state_dict.pop(k)
+        return state_dict
+
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
+
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
 
 models += [SVD_img2vid]
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
new file mode 100644
index 00000000..5f4bd354
--- /dev/null
+++ b/comfy_extras/nodes_audio.py
@@ -0,0 +1,128 @@
+import torchaudio
+import torch
+import comfy.model_management
+import folder_paths
+import os
+import hashlib
+
+class EmptyLatentAudio:
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "generate"
+
+    CATEGORY = "_for_testing/audio"
+
+    def generate(self):
+        batch_size = 1
+        latent = torch.zeros([batch_size, 64, 1024], device=self.device)
+        return ({"samples":latent, "type": "audio"}, )
+
+class VAEEncodeAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "_for_testing/audio"
+
+    def encode(self, vae, audio):
+        t = vae.encode(audio["waveform"].movedim(1, -1))
+        return ({"samples":t}, )
+
+class VAEDecodeAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
+    RETURN_TYPES = ("AUDIO",)
+    FUNCTION = "decode"
+
+    CATEGORY = "_for_testing/audio"
+
+    def decode(self, vae, samples):
+        audio = vae.decode(samples["samples"]).movedim(-1, 1)
+        return ({"waveform": audio, "sample_rate": 44100}, )
+
+class SaveAudio:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+        self.type = "output"
+        self.prefix_append = ""
+        self.compress_level = 4
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "audio": ("AUDIO", ),
+                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+                }
+
+    RETURN_TYPES = ()
+    FUNCTION = "save_audio"
+
+    OUTPUT_NODE = True
+
+    CATEGORY = "_for_testing/audio"
+
+    def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
+        filename_prefix += self.prefix_append
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        results = list()
+        for (batch_number, waveform) in enumerate(audio["waveform"]):
+            #TODO: metadata
+            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+            file = f"{filename_with_batch_num}_{counter:05}_.flac"
+            torchaudio.save(os.path.join(full_output_folder, file), waveform, audio["sample_rate"], format="FLAC")
+            results.append({
+                "filename": file,
+                "subfolder": subfolder,
+                "type": self.type
+            })
+            counter += 1
+
+        return { "ui": { "audio": results } }
+
+class LoadAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
+        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
+        return {"required": {"audio": [sorted(files), ]}, }
+
+    CATEGORY = "_for_testing/audio"
+
+    RETURN_TYPES = ("AUDIO", )
+    FUNCTION = "load"
+
+    def load(self, audio):
+        audio_path = folder_paths.get_annotated_filepath(audio)
+        waveform, sample_rate = torchaudio.load(audio_path)
+        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
+        return (audio, )
+
+    @classmethod
+    def IS_CHANGED(s, audio):
+        audio_path = folder_paths.get_annotated_filepath(audio)
+        m = hashlib.sha256()
+        with open(audio_path, 'rb') as f:
+            m.update(f.read())
+        return m.digest().hex()
+
+    @classmethod
+    def VALIDATE_INPUTS(s, audio):
+        if not folder_paths.exists_annotated_filepath(audio):
+            return "Invalid audio file: {}".format(audio)
+        return True
+
+NODE_CLASS_MAPPINGS = {
+    "EmptyLatentAudio": EmptyLatentAudio,
+    "VAEEncodeAudio": VAEEncodeAudio,
+    "VAEDecodeAudio": VAEDecodeAudio,
+    "SaveAudio": SaveAudio,
+    "LoadAudio": LoadAudio,
+}
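For reference, the audio nodes above pass sound around as a plain dict rather than a dedicated type: a "waveform" tensor shaped [batch, channels, samples] plus an integer "sample_rate" (LoadAudio adds the batch dimension with unsqueeze(0), VAEDecodeAudio reports 44100 Hz, and SaveAudio writes each batch item as FLAC). A minimal standalone sketch of that convention follows; it is illustrative only and not part of the patch, and the output file name is made up:

import torch
import torchaudio

# One second of stereo silence in the AUDIO dict format used by the nodes above.
waveform = torch.zeros(1, 2, 44100)  # [batch, channels, samples]
audio = {"waveform": waveform, "sample_rate": 44100}

# SaveAudio-style write: iterate the batch dimension and save each clip as FLAC.
for clip in audio["waveform"]:
    torchaudio.save("example_00001_.flac", clip, audio["sample_rate"], format="FLAC")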
"round": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, sigma_max, sigma_min): + m = model.clone() + + latent_format = None + sigma_data = 1.0 + if sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousV, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(sigma_min, sigma_max, sigma_data) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class RescaleCFG: @classmethod def INPUT_TYPES(s): @@ -238,6 +268,7 @@ class RescaleCFG: NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingContinuousV": ModelSamplingContinuousV, "ModelSamplingStableCascade": ModelSamplingStableCascade, "ModelSamplingSD3": ModelSamplingSD3, "RescaleCFG": RescaleCFG, diff --git a/nodes.py b/nodes.py index 6fbeb377..0b2a96f7 100644 --- a/nodes.py +++ b/nodes.py @@ -818,7 +818,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -826,11 +826,14 @@ class CLIPLoader: CATEGORY = "advanced/loaders" def load_clip(self, clip_name, type="stable_diffusion"): - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION if type == "stable_cascade": clip_type = comfy.sd.CLIPType.STABLE_CASCADE elif type == "sd3": clip_type = comfy.sd.CLIPType.SD3 + elif type == "stable_audio": + clip_type = comfy.sd.CLIPType.STABLE_AUDIO + else: + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_path = folder_paths.get_full_path("clip", clip_name) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) @@ -1973,6 +1976,7 @@ def init_custom_nodes(): "nodes_attention_multiply.py", "nodes_advanced_samplers.py", "nodes_webcam.py", + "nodes_audio.py", "nodes_sd3.py", ] diff --git a/requirements.txt b/requirements.txt index 8f681f8f..85e1dc9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ torch torchsde torchvision +torchaudio einops transformers>=4.25.1 safetensors>=0.4.2