Use torch.nn.functional.linear in RGB preview code.
Add an optional bias to the latent RGB preview code.
This commit is contained in:
parent
3bb4dec720
commit
a9e459c2a4
|
@ -4,6 +4,7 @@ class LatentFormat:
|
||||||
scale_factor = 1.0
|
scale_factor = 1.0
|
||||||
latent_channels = 4
|
latent_channels = 4
|
||||||
latent_rgb_factors = None
|
latent_rgb_factors = None
|
||||||
|
latent_rgb_factors_bias = None
|
||||||
taesd_decoder_name = None
|
taesd_decoder_name = None
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
|
|
|
@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
||||||
|
|
||||||
|
|
||||||
class Latent2RGBPreviewer(LatentPreviewer):
|
class Latent2RGBPreviewer(LatentPreviewer):
|
||||||
def __init__(self, latent_rgb_factors):
|
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
|
||||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
|
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
|
||||||
|
self.latent_rgb_factors_bias = None
|
||||||
|
if latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
|
||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
||||||
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
if self.latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||||
|
|
||||||
|
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||||
|
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||||
|
|
||||||
return preview_to_image(latent_image)
|
return preview_to_image(latent_image)
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
|
||||||
|
|
||||||
if previewer is None:
|
if previewer is None:
|
||||||
if latent_format.latent_rgb_factors is not None:
|
if latent_format.latent_rgb_factors is not None:
|
||||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
|
||||||
return previewer
|
return previewer
|
||||||
|
|
||||||
def prepare_callback(model, steps, x0_output_dict=None):
|
def prepare_callback(model, steps, x0_output_dict=None):
|
||||||
|
|
Loading…
Reference in New Issue