Use torch.nn.functional.linear in RGB preview code.

Add an optional bias to the latent RGB preview code.
This commit is contained in:
comfyanonymous 2024-09-29 11:13:53 -04:00
parent 3bb4dec720
commit a9e459c2a4
2 changed files with 13 additions and 4 deletions

View File

@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0 scale_factor = 1.0
latent_channels = 4 latent_channels = 4
latent_rgb_factors = None latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None taesd_decoder_name = None
def process_in(self, latent): def process_in(self, latent):

View File

@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors): def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu") self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
def decode_latent_to_preview(self, x0): def decode_latent_to_preview(self, x0):
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
return preview_to_image(latent_image) return preview_to_image(latent_image)
@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
if previewer is None: if previewer is None:
if latent_format.latent_rgb_factors is not None: if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
return previewer return previewer
def prepare_callback(model, steps, x0_output_dict=None): def prepare_callback(model, steps, x0_output_dict=None):