2024-06-15 16:14:56 +00:00
|
|
|
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from typing import Literal, Dict, Any
|
|
|
|
import math
|
|
|
|
import comfy.ops
|
|
|
|
# Layer namespace used for every conv in this module. Judging by the name,
# ``disable_weight_init`` provides layers that skip PyTorch's default weight
# initialisation (weights presumably come from a checkpoint) -- TODO confirm
# against comfy/ops.py.
ops = comfy.ops.disable_weight_init
|
|
|
|
|
|
|
|
def vae_sample(mean, scale):
    """Draw a reparameterised Gaussian sample and compute its KL penalty.

    The standard deviation is ``softplus(scale) + 1e-4``, and the sample is
    ``eps * stdev + mean`` with ``eps`` drawn from ``torch.randn_like(mean)``.

    Args:
        mean: Tensor of means.
        scale: Tensor of unconstrained scales, same shape as ``mean``.

    Returns:
        ``(latents, kl)``:
        - ``latents`` has the shape of ``mean``.
        - ``kl`` is a scalar tensor. It is the per-element term
          ``mean**2 + var - log(var) - 1``, summed over dim 1 and then
          averaged over all remaining elements.
        - There is no 0.5 factor, so ``kl`` is twice the analytic KL to N(0, 1).
    """
    std = nn.functional.softplus(scale) + 1e-4
    variance = std * std
    log_variance = torch.log(variance)

    noise = torch.randn_like(mean)
    latents = noise * std + mean

    kl_terms = mean * mean + variance - log_variance - 1
    kl = kl_terms.sum(1).mean()
    return latents, kl
|
|
|
|
|
|
|
|
class VAEBottleneck(nn.Module):
    """Continuous VAE bottleneck.

    ``encode`` samples latents from a mean/scale split of its input via
    ``vae_sample``; ``decode`` passes its input through unchanged.
    """

    def __init__(self):
        super().__init__()
        # Marks this bottleneck as producing non-discrete latents.
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        """Sample latents from ``x``.

        ``x`` is split into two equal chunks along dim 1. The first chunk is
        the mean and the second is the (pre-softplus) scale.

        Args:
            x: Tensor whose dim 1 holds mean and scale concatenated.
            return_info: If true, also return a dict with the KL term under
                the key ``"kl"``.
            **kwargs: Accepted and ignored.

        Returns:
            The sampled latents, or ``(latents, info)`` when ``return_info``.
        """
        mean, scale = x.chunk(2, dim=1)
        latents, kl = vae_sample(mean, scale)
        if not return_info:
            return latents
        return latents, {"kl": kl}

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
|
|
|
|
|
|
|
|
|
|
|
|
def snake_beta(x, alpha, beta):
    """Apply the SnakeBeta activation ``x + sin(x * alpha)**2 / beta``.

    ``alpha`` and ``beta`` must broadcast against ``x``. A 1e-9 epsilon is
    added to ``beta`` before taking its reciprocal to avoid division by zero.
    """
    inv_beta = 1.0 / (beta + 1e-9)
    periodic = torch.sin(x * alpha) ** 2
    return x + inv_beta * periodic
|
|
|
|
|
|
|
|
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
|
|
|
class SnakeBeta(nn.Module):
    """Per-channel SnakeBeta activation with learnable ``alpha`` and ``beta``.

    Computes ``x + sin(x * alpha)**2 / beta`` with one ``alpha``/``beta`` pair
    per channel. The pair is reshaped to ``[1, C, 1]`` so that it broadcasts
    over an input laid out as ``[B, C, T]``.

    Args:
        in_features: Number of channels ``C``.
        alpha: Initial value of both parameters in linear mode. It has no
            effect in log-scale mode, where the parameters start at zero.
        alpha_trainable: Accepted for signature compatibility. It is currently
            ignored, so both parameters always require grad.
        alpha_logscale: If true, the parameters are stored in log space and
            exponentiated in ``forward``. Otherwise they are used as-is.
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # Log-space parameters start at 0, i.e. exp(0) == 1 in forward().
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:
            # Linear-space parameters start at ``alpha``.
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        # NOTE: ``alpha_trainable`` is intentionally not applied here
        # (requires_grad is left at its default of True).

        # Not read by forward(); snake_beta() applies the same 1e-9 epsilon.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation to ``x`` (expected layout ``[B, C, T]``).

        The parameters are moved to both the device and the dtype of ``x``.
        Otherwise a half-precision input would be promoted to the parameters'
        dtype, and downstream layers would receive an unexpected dtype.
        """
        # [C] -> [1, C, 1] so the per-channel values line up with [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(device=x.device, dtype=x.dtype)
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(device=x.device, dtype=x.dtype)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x
|
|
|
|
|
|
|
|
def WNConv1d(*args, **kwargs):
    """Create an ``ops.Conv1d`` with weight normalisation applied.

    All arguments are forwarded to ``ops.Conv1d``. The choice of
    weight-norm implementation is:
    - Preferred: the parametrization-based
      ``torch.nn.utils.parametrizations.weight_norm``.
    - Fallback, on PyTorch versions that do not provide it: the legacy
      ``torch.nn.utils.weight_norm``.

    Only a missing attribute triggers the fallback. Other errors, such as bad
    conv arguments, propagate instead of being swallowed.
    """
    conv = ops.Conv1d(*args, **kwargs)
    try:
        apply_weight_norm = torch.nn.utils.parametrizations.weight_norm
    except AttributeError:
        # Older PyTorch without parametrizations.weight_norm.
        apply_weight_norm = torch.nn.utils.weight_norm
    return apply_weight_norm(conv)
|
2024-06-15 16:14:56 +00:00
|
|
|
|
|
|
|
def WNConvTranspose1d(*args, **kwargs):
    """Create an ``ops.ConvTranspose1d`` with weight normalisation applied.

    All arguments are forwarded to ``ops.ConvTranspose1d``. The choice of
    weight-norm implementation is:
    - Preferred: the parametrization-based
      ``torch.nn.utils.parametrizations.weight_norm``.
    - Fallback, on PyTorch versions that do not provide it: the legacy
      ``torch.nn.utils.weight_norm``.

    Only a missing attribute triggers the fallback. Other errors propagate.
    """
    conv = ops.ConvTranspose1d(*args, **kwargs)
    try:
        apply_weight_norm = torch.nn.utils.parametrizations.weight_norm
    except AttributeError:
        # Older PyTorch without parametrizations.weight_norm.
        apply_weight_norm = torch.nn.utils.weight_norm
    return apply_weight_norm(conv)
|
2024-06-15 16:14:56 +00:00
|
|
|
|
|
|
|
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build the activation module named by ``activation``.

    Args:
        activation: One of:
            - ``"elu"``: ``torch.nn.ELU``.
            - ``"snake"``: ``SnakeBeta`` over ``channels``.
            - ``"none"``: ``torch.nn.Identity``.
        antialias: If true, wrap the activation in ``Activation1d``.
        channels: Channel count; only used for ``"snake"``.

    Raises:
        ValueError: If ``activation`` is not one of the names above.
    """
    def _base():
        if activation == "elu":
            return torch.nn.ELU()
        if activation == "snake":
            return SnakeBeta(channels)
        if activation == "none":
            return torch.nn.Identity()
        raise ValueError(f"Unknown activation {activation}")

    module = _base()
    # NOTE(review): ``Activation1d`` is not defined or imported anywhere in the
    # visible file. ``antialias=True`` would therefore raise NameError unless
    # the name is provided elsewhere -- confirm before enabling.
    return Activation1d(module) if antialias else module
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualUnit(nn.Module):
    """Dilated residual block: act -> conv(k=7, dilated) -> act -> conv(k=1), plus skip.

    The skip connection adds the input to an ``out_channels`` output, so
    ``in_channels`` must equal ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()
        self.dilation = dilation

        kind = "snake" if use_snake else "elu"
        kernel = 7
        # (dilation * 6) // 2 == 3 * dilation, which keeps the sequence length
        # unchanged for this stride-1 conv.
        pad = (dilation * (kernel - 1)) // 2

        self.layers = nn.Sequential(
            get_activation(kind, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=kernel, dilation=dilation, padding=pad),
            get_activation(kind, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1),
        )

    def forward(self, x):
        """Return ``layers(x) + x``."""
        return self.layers(x) + x
|
|
|
|
|
|
|
|
class EncoderBlock(nn.Module):
    """Encoder stage with a strided conv at the end.

    Layers: three residual units (dilations 1, 3, 9) at ``in_channels``, then
    an activation, then a conv with ``kernel_size=2*stride`` and ``stride``
    that maps to ``out_channels``.

    ``antialias_activation`` is only applied to the activation before the
    strided conv. The residual units do not receive it.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        residuals = [
            ResidualUnit(in_channels=in_channels, out_channels=in_channels,
                         dilation=d, use_snake=use_snake)
            for d in (1, 3, 9)
        ]
        act = get_activation("snake" if use_snake else "elu",
                             antialias=antialias_activation, channels=in_channels)
        downsample = WNConv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=2 * stride, stride=stride,
                              padding=math.ceil(stride / 2))

        self.layers = nn.Sequential(*residuals, act, downsample)

    def forward(self, x):
        """Run the stage on ``x``."""
        return self.layers(x)
|
|
|
|
|
|
|
|
class DecoderBlock(nn.Module):
    """Decoder stage: activation, upsampling layer, then three residual units.

    The upsampling layer is one of two forms:
    - Default: a weight-normalised transposed conv (``kernel_size=2*stride``,
      ``stride``).
    - With ``use_nearest_upsample``: nearest-neighbour ``nn.Upsample`` by
      ``stride``, followed by a bias-free stride-1 conv with ``'same'`` padding.

    The residual units use dilations 1, 3, 9 at ``out_channels`` and do not
    receive ``antialias_activation``.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if not use_nearest_upsample:
            upsampler = WNConvTranspose1d(in_channels=in_channels,
                                          out_channels=out_channels,
                                          kernel_size=2 * stride, stride=stride,
                                          padding=math.ceil(stride / 2))
        else:
            upsampler = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2 * stride,
                         stride=1,
                         bias=False,
                         padding='same'),
            )

        act = get_activation("snake" if use_snake else "elu",
                             antialias=antialias_activation, channels=in_channels)
        residuals = [
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=d, use_snake=use_snake)
            for d in (1, 3, 9)
        ]

        self.layers = nn.Sequential(act, upsampler, *residuals)

    def forward(self, x):
        """Run the stage on ``x``."""
        return self.layers(x)
|
|
|
|
|
|
|
|
class OobleckEncoder(nn.Module):
    """Convolutional encoder.

    Layers: a k=7 stem conv, one ``EncoderBlock`` per entry of ``c_mults``,
    then an activation and a k=3 projection to ``latent_dim`` channels.

    Args:
        in_channels: Input channel count.
        channels: Base width; stage widths are ``mult * channels``.
        latent_dim: Output channel count.
        c_mults: Channel multiplier per stage. Any sequence is accepted.
        strides: Stride per stage. Must have at least ``len(c_mults)`` entries.
        use_snake: Use SnakeBeta instead of ELU.
        antialias_activation: Applied only to the final activation. The
            EncoderBlocks do not receive it.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=(1, 2, 4, 8),
                 strides=(2, 4, 8, 8),
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        # Prepend the stem multiplier. This builds a new list, so tuples work
        # and the caller's sequence is never modified.
        c_mults = [1, *c_mults]

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth - 1):
            layers += [EncoderBlock(in_channels=c_mults[i] * channels, out_channels=c_mults[i + 1] * channels,
                                    stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` to ``latent_dim`` channels."""
        return self.layers(x)
|
|
|
|
|
|
|
|
|
|
|
|
class OobleckDecoder(nn.Module):
    """Convolutional decoder that walks ``c_mults``/``strides`` in reverse.

    Layers, in order:
    - A k=7 stem conv from ``latent_dim`` to the widest stage.
    - One ``DecoderBlock`` per stage, from widest to narrowest.
    - An activation and a bias-free k=7 conv to ``out_channels``.
    - ``tanh`` if ``final_tanh``, otherwise identity.

    Args:
        out_channels: Output channel count.
        channels: Base width; stage widths are ``mult * channels``.
        latent_dim: Input channel count.
        c_mults: Channel multiplier per stage. Any sequence is accepted.
        strides: Stride per stage. Must have at least ``len(c_mults)`` entries.
        use_snake: Use SnakeBeta instead of ELU.
        antialias_activation: Forwarded to the DecoderBlocks and the final
            activation.
        use_nearest_upsample: Forwarded to the DecoderBlocks.
        final_tanh: Append ``nn.Tanh`` instead of ``nn.Identity``.
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=(1, 2, 4, 8),
                 strides=(2, 4, 8, 8),
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        # Prepend the base multiplier. This builds a new list, so tuples work
        # and the caller's sequence is never modified.
        c_mults = [1, *c_mults]

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]

        # Stages from widest (index depth-1) down to the base width (index 0).
        for i in range(self.depth - 1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
            )]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Decode latents ``x`` to ``out_channels`` channels."""
        return self.layers(x)
|
|
|
|
|
|
|
|
|
|
|
|
class AudioOobleckVAE(nn.Module):
    """Oobleck encoder + VAE bottleneck + Oobleck decoder.

    The encoder emits ``2 * latent_dim`` channels. The bottleneck splits them
    into mean and scale halves and samples ``latent_dim`` latent channels. The
    decoder maps latents back to ``in_channels`` channels.

    Args:
        in_channels: Channel count of the input, also used as the decoder's
            output channel count.
        channels: Base width for encoder and decoder.
        latent_dim: Number of latent channels after sampling.
        c_mults: Channel multiplier per stage. Any sequence is accepted.
        strides: Stride per stage. Any sequence is accepted.
        use_snake: Use SnakeBeta activations.
        antialias_activation: Forwarded to encoder and decoder.
        use_nearest_upsample: Forwarded to the decoder.
        final_tanh: Forwarded to the decoder.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults=(1, 2, 4, 8, 16),
                 strides=(2, 4, 4, 8, 8),
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        # Hand the sub-modules fresh lists. They may concatenate onto
        # ``c_mults``, and this keeps the defaults immutable.
        c_mults = list(c_mults)
        strides = list(strides)
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        """Encode ``x`` and return sampled latents (the KL term is discarded)."""
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        """Decode latents ``x`` back to the input channel space."""
        return self.decoder(self.bottleneck.decode(x))
|
|
|
|
|