36 lines
1.5 KiB
Python
36 lines
1.5 KiB
Python
from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
|
from ..diffusionmodules.openaimodel import Timestep
|
|
import torch
|
|
|
|
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    """Noise-augment embeddings and return an embedding of the noise level.

    Inputs are normalized with per-dimension statistics (``data_mean`` /
    ``data_std``), passed through the inherited ``q_sample`` at a per-sample
    noise level, mapped back to the original statistics, and returned together
    with the ``Timestep`` embedding of those noise levels.
    """

    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        """Build the augmenter.

        Args:
            *args, **kwargs: Forwarded unchanged to
                ``ImageConcatWithNoiseAugmentation.__init__``.
            clip_stats_path: Optional path to a file that ``torch.load`` turns
                into a ``(mean, std)`` pair of tensors. When ``None``, identity
                statistics (zeros / ones of length ``timestep_dim``) are used,
                which makes ``scale`` and ``unscale`` no-ops.
            timestep_dim: Width passed to ``Timestep`` for the noise-level
                embedding; also the length of the default identity statistics.
        """
        super().__init__(*args, **kwargs)
        if clip_stats_path is None:
            # NOTE(review): identity stats are sized by timestep_dim, so this
            # assumes the embedding width equals timestep_dim — confirm.
            clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
        else:
            # NOTE(security): torch.load unpickles the file and can execute
            # arbitrary code; only point this at trusted stats files (newer
            # torch versions accept weights_only=True to restrict loading).
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        # [None, :] adds a leading axis so the stats broadcast over the batch
        # (presumably x is (batch, dim) — confirm). persistent=False keeps the
        # stats out of state_dict; they are rebuilt here on every construction.
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        """Re-normalize ``x`` to centered mean and unit variance using the stats."""
        # Move stats to x's device so the call works even if x is elsewhere.
        mean = self.data_mean.to(x.device)
        std = self.data_std.to(x.device)
        return (x - mean) / std

    def unscale(self, x):
        """Map normalized ``x`` back to the original data statistics."""
        mean = self.data_mean.to(x.device)
        std = self.data_std.to(x.device)
        return x * std + mean

    def forward(self, x, noise_level=None):
        """Noise-augment ``x`` and embed the noise level.

        Args:
            x: Input batch; ``x.shape[0]`` is used as the batch size and the
                trailing dimension must broadcast against the stored stats.
            noise_level: Optional tensor of per-sample noise levels. When
                ``None``, one integer level per sample is drawn uniformly from
                ``[0, self.max_noise_level)`` on ``x``'s device.

        Returns:
            ``(z, emb)``: the noised input expressed in the original data
            statistics, and ``self.time_embed`` applied to the noise levels.

        Raises:
            TypeError: If ``noise_level`` is given but is not a ``torch.Tensor``.
        """
        if noise_level is None:
            # max_noise_level is provided by the parent class.
            noise_level = torch.randint(
                0, self.max_noise_level, (x.shape[0],), device=x.device
            ).long()
        elif not isinstance(noise_level, torch.Tensor):
            # Explicit check instead of `assert`, which is stripped under -O.
            raise TypeError(
                f"noise_level must be a torch.Tensor, got {type(noise_level).__name__}"
            )
        x = self.scale(x)
        # q_sample is inherited; presumably it applies forward-diffusion noise
        # for the given levels — confirm in the parent class.
        z = self.q_sample(x, noise_level)
        z = self.unscale(z)
        noise_level = self.time_embed(noise_level)
        return z, noise_level