Remove useless code.

This commit is contained in:
comfyanonymous 2023-06-13 02:40:58 -04:00
parent 274dff3257
commit 2b14041d4b
8 changed files with 0 additions and 870 deletions

View File

@ -1,105 +0,0 @@
from functools import reduce
import math
import operator
import numpy as np
from skimage import transform
import torch
from torch import nn
def translate2d(tx, ty):
mat = [[1, 0, tx],
[0, 1, ty],
[0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def scale2d(sx, sy):
mat = [[sx, 0, 0],
[ 0, sy, 0],
[ 0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def rotate2d(theta):
mat = [[torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
[ 0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
class KarrasAugmentationPipeline:
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
self.a_prob = a_prob
self.a_scale = a_scale
self.a_aniso = a_aniso
self.a_trans = a_trans
def __call__(self, image):
h, w = image.size
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
# x-flip
a0 = torch.randint(2, []).float()
mats.append(scale2d(1 - 2 * a0, 1))
# y-flip
do = (torch.rand([]) < self.a_prob).float()
a1 = torch.randint(2, []).float() * do
mats.append(scale2d(1, 1 - 2 * a1))
# scaling
do = (torch.rand([]) < self.a_prob).float()
a2 = torch.randn([]) * do
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
# rotation
do = (torch.rand([]) < self.a_prob).float()
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
mats.append(rotate2d(-a3))
# anisotropy
do = (torch.rand([]) < self.a_prob).float()
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
a5 = torch.randn([]) * do
mats.append(rotate2d(a4))
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
mats.append(rotate2d(-a4))
# translation
do = (torch.rand([]) < self.a_prob).float()
a6 = torch.randn([]) * do
a7 = torch.randn([]) * do
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
# form the transformation matrix and conditioning vector
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
mat = reduce(operator.matmul, mats)
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
# apply the transformation
image_orig = np.array(image, dtype=np.float32) / 255
if image_orig.ndim == 2:
image_orig = image_orig[..., None]
tf = transform.AffineTransform(mat.numpy())
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
return image, image_orig, cond
class KarrasAugmentWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
if aug_cond is None:
aug_cond = input.new_zeros([input.shape[0], 9])
if mapping_cond is None:
mapping_cond = aug_cond
else:
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
def set_skip_stages(self, skip_stages):
return self.inner_model.set_skip_stages(skip_stages)
def set_patch_size(self, patch_size):
return self.inner_model.set_patch_size(patch_size)

View File

@ -1,110 +0,0 @@
from functools import partial
import json
import math
import warnings
from jsonmerge import merge
from . import augmentation, layers, models, utils
def load_config(file):
defaults = {
'model': {
'sigma_data': 1.,
'patch_size': 1,
'dropout_rate': 0.,
'augment_wrapper': True,
'augment_prob': 0.,
'mapping_cond_dim': 0,
'unet_cond_dim': 0,
'cross_cond_dim': 0,
'cross_attn_depths': None,
'skip_stages': 0,
'has_variance': False,
},
'dataset': {
'type': 'imagefolder',
},
'optimizer': {
'type': 'adamw',
'lr': 1e-4,
'betas': [0.95, 0.999],
'eps': 1e-6,
'weight_decay': 1e-3,
},
'lr_sched': {
'type': 'inverse',
'inv_gamma': 20000.,
'power': 1.,
'warmup': 0.99,
},
'ema_sched': {
'type': 'inverse',
'power': 0.6667,
'max_value': 0.9999
},
}
config = json.load(file)
return merge(defaults, config)
def make_model(config):
config = config['model']
assert config['type'] == 'image_v1'
model = models.ImageDenoiserModelV1(
config['input_channels'],
config['mapping_out'],
config['depths'],
config['channels'],
config['self_attn_depths'],
config['cross_attn_depths'],
patch_size=config['patch_size'],
dropout_rate=config['dropout_rate'],
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
unet_cond_dim=config['unet_cond_dim'],
cross_cond_dim=config['cross_cond_dim'],
skip_stages=config['skip_stages'],
has_variance=config['has_variance'],
)
if config['augment_wrapper']:
model = augmentation.KarrasAugmentWrapper(model)
return model
def make_denoiser_wrapper(config):
config = config['model']
sigma_data = config.get('sigma_data', 1.)
has_variance = config.get('has_variance', False)
if not has_variance:
return partial(layers.Denoiser, sigma_data=sigma_data)
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
def make_sample_density(config):
sd_config = config['sigma_sample_density']
sigma_data = config['sigma_data']
if sd_config['type'] == 'lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
return partial(utils.rand_log_normal, loc=loc, scale=scale)
if sd_config['type'] == 'loglogistic':
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
scale = sd_config['scale'] if 'scale' in sd_config else 0.5
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'loguniform':
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'v-diffusion':
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'split-lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
raise ValueError('Unknown sample density type')

View File

@ -1,134 +0,0 @@
import math
import os
from pathlib import Path
from cleanfid.inception_torchscript import InceptionV3W
import clip
from resize_right import resize
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm.auto import trange
from . import utils
class InceptionV3FeatureExtractor(nn.Module):
def __init__(self, device='cpu'):
super().__init__()
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
self.model = InceptionV3W(str(path), resize_inside=False).to(device)
self.size = (299, 299)
def forward(self, x):
if x.shape[2:4] != self.size:
x = resize(x, out_shape=self.size, pad_mode='reflect')
if x.shape[1] == 1:
x = torch.cat([x] * 3, dim=1)
x = (x * 127.5 + 127.5).clamp(0, 255)
return self.model(x)
class CLIPFeatureExtractor(nn.Module):
def __init__(self, name='ViT-L/14@336px', device='cpu'):
super().__init__()
self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711))
self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
def forward(self, x):
if x.shape[2:4] != self.size:
x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
x = self.normalize(x)
x = self.model.encode_image(x).float()
x = F.normalize(x) * x.shape[1] ** 0.5
return x
def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
n_per_proc = math.ceil(n / accelerator.num_processes)
feats_all = []
try:
for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
cur_batch_size = min(n - i, batch_size)
samples = sample_fn(cur_batch_size)[:cur_batch_size]
feats_all.append(accelerator.gather(extractor_fn(samples)))
except StopIteration:
pass
return torch.cat(feats_all)[:n]
def polynomial_kernel(x, y):
d = x.shape[-1]
dot = x @ y.transpose(-2, -1)
return (dot / d + 1) ** 3
def squared_mmd(x, y, kernel=polynomial_kernel):
m = x.shape[-2]
n = y.shape[-2]
kxx = kernel(x, x)
kyy = kernel(y, y)
kxy = kernel(x, y)
kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
kxy_sum = kxy.sum([-1, -2])
term_1 = kxx_sum / m / (m - 1)
term_2 = kyy_sum / n / (n - 1)
term_3 = kxy_sum * 2 / m / n
return term_1 + term_2 - term_3
@utils.tf32_mode(matmul=False)
def kid(x, y, max_size=5000):
x_size, y_size = x.shape[0], y.shape[0]
n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
total_mmd = x.new_zeros([])
for i in range(n_partitions):
cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
return total_mmd / n_partitions
class _MatrixSquareRootEig(torch.autograd.Function):
@staticmethod
def forward(ctx, a):
vals, vecs = torch.linalg.eigh(a)
ctx.save_for_backward(vals, vecs)
return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
@staticmethod
def backward(ctx, grad_output):
vals, vecs = ctx.saved_tensors
d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
vecs_t = vecs.transpose(-2, -1)
return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
def sqrtm_eig(a):
if a.ndim < 2:
raise RuntimeError('tensor of matrices must have at least 2 dimensions')
if a.shape[-2] != a.shape[-1]:
raise RuntimeError('tensor must be batches of square matrices')
return _MatrixSquareRootEig.apply(a)
@utils.tf32_mode(matmul=False)
def fid(x, y, eps=1e-8):
x_mean = x.mean(dim=0)
y_mean = y.mean(dim=0)
mean_term = (x_mean - y_mean).pow(2).sum()
x_cov = torch.cov(x.T)
y_cov = torch.cov(y.T)
eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
x_cov = x_cov + eps_eye
y_cov = y_cov + eps_eye
x_cov_sqrt = sqrtm_eig(x_cov)
cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
return mean_term + cov_term

View File

@ -1,99 +0,0 @@
import torch
from torch import nn
class DDPGradientStatsHook:
def __init__(self, ddp_module):
try:
ddp_module.register_comm_hook(self, self._hook_fn)
except AttributeError:
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
self._clear_state()
def _clear_state(self):
self.bucket_sq_norms_small_batch = []
self.bucket_sq_norms_large_batch = []
@staticmethod
def _hook_fn(self, bucket):
buf = bucket.buffer()
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
def callback(fut):
buf = fut.value()[0]
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
return buf
return fut.then(callback)
def get_stats(self):
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
self._clear_state()
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
return stats[0].item(), stats[1].item()
class GradientNoiseScale:
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
from _An Empirical Model of Large-Batch Training_,
https://arxiv.org/abs/1812.06162).
Args:
beta (float): The decay factor for the exponential moving averages used to
calculate the gradient noise scale.
Default: 0.9998
eps (float): Added for numerical stability.
Default: 1e-8
"""
def __init__(self, beta=0.9998, eps=1e-8):
self.beta = beta
self.eps = eps
self.ema_sq_norm = 0.
self.ema_var = 0.
self.beta_cumprod = 1.
self.gradient_noise_scale = float('nan')
def state_dict(self):
"""Returns the state of the object as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the object's state.
Args:
state_dict (dict): object state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
"""Updates the state with a new batch's gradient statistics, and returns the
current gradient noise scale.
Args:
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
per sample gradients.
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
per sample gradients.
n_small_batch (int): The batch size of the individual microbatch or per sample
gradients (1 if per sample).
n_large_batch (int): The total batch size of the mean of the microbatch or
per sample gradients.
"""
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
self.beta_cumprod *= self.beta
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
return self.gradient_noise_scale
def get_gns(self):
"""Returns the current gradient noise scale."""
return self.gradient_noise_scale
def get_stats(self):
"""Returns the current (debiased) estimates of the squared mean gradient
and gradient variance."""
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)

View File

@ -1,246 +0,0 @@
import math
from einops import rearrange, repeat
import torch
from torch import nn
from torch.nn import functional as F
from . import utils
# Karras et al. preconditioned denoiser
class Denoiser(nn.Module):
"""A Karras et al. preconditioner for denoising diffusion models."""
def __init__(self, inner_model, sigma_data=1.):
super().__init__()
self.inner_model = inner_model
self.sigma_data = sigma_data
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
class DenoiserWithVariance(Denoiser):
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
logvar = utils.append_dims(logvar, model_output.ndim)
target = (input - c_skip * noised_input) / c_out
losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
return losses.flatten(1).mean(1)
# Residual blocks
class ResidualBlock(nn.Module):
def __init__(self, *main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
# Noise level (and other) conditioning
class ConditionedModule(nn.Module):
pass
class UnconditionedModule(ConditionedModule):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, input, cond=None):
return self.module(input)
class ConditionedSequential(nn.Sequential, ConditionedModule):
def forward(self, input, cond):
for module in self:
if isinstance(module, ConditionedModule):
input = module(input, cond)
else:
input = module(input)
return input
class ConditionedResidualBlock(ConditionedModule):
def __init__(self, *main, skip=None):
super().__init__()
self.main = ConditionedSequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input, cond):
skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
return self.main(input, cond) + skip
class AdaGN(ConditionedModule):
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.cond_key = cond_key
self.mapper = nn.Linear(feats_in, c_out * 2)
def forward(self, input, cond):
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
input = F.group_norm(input, self.num_groups, eps=self.eps)
return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
# Attention
class SelfAttention2d(ConditionedModule):
def __init__(self, c_in, n_head, norm, dropout_rate=0.):
super().__init__()
assert c_in % n_head == 0
self.norm_in = norm(c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv2d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
qkv = self.qkv_proj(self.norm_in(input, cond))
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
return input + self.out_proj(y)
class CrossAttention2d(ConditionedModule):
def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
cond_key='cross', cond_key_padding='cross_padding'):
super().__init__()
assert c_dec % n_head == 0
self.cond_key = cond_key
self.cond_key_padding = cond_key_padding
self.norm_enc = nn.LayerNorm(c_enc)
self.norm_dec = norm_dec(c_dec)
self.n_head = n_head
self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
self.kv_proj = nn.Linear(c_enc, c_dec * 2)
self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
q = self.q_proj(self.norm_dec(input, cond))
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
k, v = kv.chunk(2, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale))
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
att = att.softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3)
y = y.contiguous().view([n, c, h, w])
return input + self.out_proj(y)
# Downsampling/upsampling
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']
class Downsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
class Upsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
# Embeddings
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.):
super().__init__()
assert out_features % 2 == 0
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
# U-Nets
class UNet(ConditionedModule):
def __init__(self, d_blocks, u_blocks, skip_stages=0):
super().__init__()
self.d_blocks = nn.ModuleList(d_blocks)
self.u_blocks = nn.ModuleList(u_blocks)
self.skip_stages = skip_stages
def forward(self, input, cond):
skips = []
for block in self.d_blocks[self.skip_stages:]:
input = block(input, cond)
skips.append(input)
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
input = block(input, cond, skip if i > 0 else None)
return input

View File

@ -1 +0,0 @@
from .image_v1 import ImageDenoiserModelV1

View File

@ -1,156 +0,0 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from .. import layers, utils
def orthogonal_(module):
nn.init.orthogonal_(module.weight)
return module
class ResConvBlock(layers.ConditionedResidualBlock):
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
super().__init__(
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
nn.GELU(),
nn.Conv2d(c_in, c_mid, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
nn.GELU(),
nn.Conv2d(c_mid, c_out, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
skip=skip)
class DBlock(layers.ConditionedSequential):
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
modules = [nn.Identity()]
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
if cross_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
super().__init__(*modules)
self.set_downsample(downsample)
def set_downsample(self, downsample):
self[0] = layers.Downsample2d() if downsample else nn.Identity()
return self
class UBlock(layers.ConditionedSequential):
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
modules = []
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
if cross_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
modules.append(nn.Identity())
super().__init__(*modules)
self.set_upsample(upsample)
def forward(self, input, cond, skip=None):
if skip is not None:
input = torch.cat([input, skip], dim=1)
return super().forward(input, cond)
def set_upsample(self, upsample):
self[-1] = layers.Upsample2d() if upsample else nn.Identity()
return self
class MappingNet(nn.Sequential):
def __init__(self, feats_in, feats_out, n_layers=2):
layers = []
for i in range(n_layers):
layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
layers.append(nn.GELU())
super().__init__(*layers)
class ImageDenoiserModelV1(nn.Module):
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
super().__init__()
self.c_in = c_in
self.channels = channels
self.unet_cond_dim = unet_cond_dim
self.patch_size = patch_size
self.has_variance = has_variance
self.timestep_embed = layers.FourierFeatures(1, feats_in)
if mapping_cond_dim > 0:
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
self.mapping = MappingNet(feats_in, feats_in)
self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
if cross_cond_dim == 0:
cross_attn_depths = [False] * len(self_attn_depths)
d_blocks, u_blocks = [], []
for i in range(len(depths)):
my_c_in = channels[max(0, i - 1)]
d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
for i in range(len(depths)):
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
my_c_out = channels[max(0, i - 1)]
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
c_noise = sigma.log() / 4
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
cond = {'cond': mapping_out}
if unet_cond is not None:
input = torch.cat([input, unet_cond], dim=1)
if cross_cond is not None:
cond['cross'] = cross_cond
cond['cross_padding'] = cross_cond_padding
if self.patch_size > 1:
input = F.pixel_unshuffle(input, self.patch_size)
input = self.proj_in(input)
input = self.u_net(input, cond)
input = self.proj_out(input)
if self.has_variance:
input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
if self.patch_size > 1:
input = F.pixel_shuffle(input, self.patch_size)
if self.has_variance and return_variance:
return input, logvar
return input
def set_skip_stages(self, skip_stages):
self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
self.u_net.skip_stages = skip_stages
for i, block in enumerate(self.u_net.d_blocks):
block.set_downsample(i > skip_stages)
for i, block in enumerate(reversed(self.u_net.u_blocks)):
block.set_upsample(i > skip_stages)
return self
def set_patch_size(self, patch_size):
self.patch_size = patch_size
self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)

View File

@ -10,25 +10,6 @@ from PIL import Image
import torch
from torch import nn, optim
from torch.utils import data
from torchvision.transforms import functional as TF
def from_pil_image(x):
"""Converts from a PIL image to a tensor."""
x = TF.to_tensor(x)
if x.ndim == 2:
x = x[..., None]
return x * 2 - 1
def to_pil_image(x):
"""Converts from a tensor to a PIL image."""
if x.ndim == 4:
assert x.shape[0] == 1
x = x[0]
if x.shape[0] == 1:
x = x[0]
return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):