Use torch.nn.functional.linear in RGB preview code.
Add an optional bias to the latent RGB preview code.
This commit is contained in:
parent
3bb4dec720
commit
a9e459c2a4
|
@ -4,6 +4,7 @@ class LatentFormat:
|
|||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
|
|
|
@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
|||
|
||||
|
||||
class Latent2RGBPreviewer(LatentPreviewer):
|
||||
def __init__(self, latent_rgb_factors):
|
||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
|
||||
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
|
||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
|
||||
self.latent_rgb_factors_bias = None
|
||||
if latent_rgb_factors_bias is not None:
|
||||
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
||||
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||
if self.latent_rgb_factors_bias is not None:
|
||||
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||
|
||||
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||
|
||||
return preview_to_image(latent_image)
|
||||
|
||||
|
||||
|
@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
|
|||
|
||||
if previewer is None:
|
||||
if latent_format.latent_rgb_factors is not None:
|
||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
|
||||
return previewer
|
||||
|
||||
def prepare_callback(model, steps, x0_output_dict=None):
|
||||
|
|
Loading…
Reference in New Issue