import torch

from ..diffusionmodules.openaimodel import Timestep
from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation


class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    """Noise-augment embeddings in a normalized space and embed the noise level.

    The input is standardized with fixed statistics (``data_mean`` /
    ``data_std``), noised with the inherited ``q_sample``, then mapped back to
    the original statistics. It is returned together with a ``Timestep``
    embedding of the noise level that was applied.
    """

    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        """Set up normalization statistics and the noise-level embedder.

        Args:
            *args, **kwargs: Forwarded unchanged to
                ``ImageConcatWithNoiseAugmentation.__init__``.
            clip_stats_path: Optional path to a file that ``torch.load``
                returns as a ``(mean, std)`` pair of tensors (presumably 1-D,
                like the default below; confirm against the stats file). When
                ``None``, identity statistics are used: zeros for the mean and
                ones for the std, each of length ``timestep_dim``.
            timestep_dim: Size passed to ``Timestep``. It is also the length
                of the default identity statistics.
        """
        super().__init__(*args, **kwargs)
        if clip_stats_path is None:
            clip_mean = torch.zeros(timestep_dim)
            clip_std = torch.ones(timestep_dim)
        else:
            # NOTE(review): torch.load unpickles arbitrary objects. Only point
            # this at trusted files. `weights_only=True` would be safer, but it
            # is not accepted by older torch releases, so it is left off here.
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        # A leading singleton axis lets the stats broadcast over a batch
        # dimension. Non-persistent buffers follow `.to()` / `.cuda()` moves
        # but are not written to the state_dict.
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        """Re-normalize ``x`` to zero mean and unit variance using the stored stats."""
        # `.to(x.device)` guards against the stats and the input living on
        # different devices.
        return (x - self.data_mean.to(x.device)) / self.data_std.to(x.device)

    def unscale(self, x):
        """Map ``x`` back from normalized space to the original data statistics."""
        return x * self.data_std.to(x.device) + self.data_mean.to(x.device)

    def forward(self, x, noise_level=None):
        """Noise-augment ``x`` and return it with an embedding of the noise level.

        Args:
            x: Tensor to augment. One noise level is drawn per entry of its
                leading dimension. It is presumably shaped ``(batch, dim)`` so
                it broadcasts with the ``(1, dim)`` stats; TODO confirm.
            noise_level: Optional tensor of per-sample noise levels. When
                ``None``, integer levels are drawn uniformly from
                ``[0, self.max_noise_level)`` on ``x``'s device.
                ``max_noise_level`` is inherited and not visible here.

        Returns:
            A tuple ``(z, emb)``. ``z`` is the noised input mapped back to the
            original data statistics. ``emb`` is ``self.time_embed`` applied to
            the noise levels.

        Raises:
            TypeError: If ``noise_level`` is given but is not a ``torch.Tensor``.
        """
        if noise_level is None:
            noise_level = torch.randint(
                0,
                self.max_noise_level,
                (x.shape[0],),
                device=x.device,
                dtype=torch.long,
            )
        elif not isinstance(noise_level, torch.Tensor):
            # An explicit raise (rather than `assert`) keeps this check active
            # under `python -O`.
            raise TypeError(
                f"noise_level must be a torch.Tensor, got {type(noise_level).__name__}"
            )
        # Noise is applied in normalized space. `q_sample` is inherited from
        # the base class (presumably a forward-diffusion sample at the given
        # level; confirm).
        z = self.unscale(self.q_sample(self.scale(x), noise_level))
        return z, self.time_embed(noise_level)