diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index ee19faea..78397d75 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -4,6 +4,7 @@ class LatentFormat:
     scale_factor = 1.0
     latent_channels = 4
     latent_rgb_factors = None
+    latent_rgb_factors_bias = None
     taesd_decoder_name = None
 
     def process_in(self, latent):
diff --git a/latent_preview.py b/latent_preview.py
index e14c72ce..ae9211a2 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
 
 
 class Latent2RGBPreviewer(LatentPreviewer):
-    def __init__(self, latent_rgb_factors):
-        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
+    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
+        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
+        self.latent_rgb_factors_bias = None
+        if latent_rgb_factors_bias is not None:
+            self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
 
     def decode_latent_to_preview(self, x0):
         self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
-        latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
+        if self.latent_rgb_factors_bias is not None:
+            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
+
+        # F.linear computes x @ weight.T + bias; the factors were transposed in __init__, so this equals the former `x @ factors` plus the optional bias.
+        latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
+
         return preview_to_image(latent_image)
 
 
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
 
     if previewer is None:
         if latent_format.latent_rgb_factors is not None:
-            previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
+            previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
     return previewer
 
 def prepare_callback(model, steps, x0_output_dict=None):