Implement support for taef1 latent previews (#4409)

* add taef1 handling to several places

* remove guess_latent_channels and add latent_channels info directly to flux model

* remove TODO

* fix numbers
This commit is contained in:
Matthew Turnshek 2024-08-16 12:53:13 -04:00 committed by GitHub
parent 05a9f3faa1
commit 1770fc77ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 1 deletions

View File

@ -141,6 +141,7 @@ class StableAudio1(LatentFormat):
latent_channels = 64 latent_channels = 64
class Flux(SD3): class Flux(SD3):
latent_channels = 16
def __init__(self): def __init__(self):
self.scale_factor = 0.3611 self.scale_factor = 0.3611
self.shift_factor = 0.1159 self.shift_factor = 0.1159
@ -162,6 +163,7 @@ class Flux(SD3):
[-0.0005, -0.0530, -0.0020], [-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680] [-0.1273, -0.0932, -0.0680]
] ]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent): def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor return (latent - self.shift_factor) * self.scale_factor

View File

@ -665,6 +665,8 @@ class VAELoader:
sd1_taesd_dec = False sd1_taesd_dec = False
sd3_taesd_enc = False sd3_taesd_enc = False
sd3_taesd_dec = False sd3_taesd_dec = False
f1_taesd_enc = False
f1_taesd_dec = False
for v in approx_vaes: for v in approx_vaes:
if v.startswith("taesd_decoder."): if v.startswith("taesd_decoder."):
@ -679,12 +681,18 @@ class VAELoader:
sd3_taesd_dec = True sd3_taesd_dec = True
elif v.startswith("taesd3_encoder."): elif v.startswith("taesd3_encoder."):
sd3_taesd_enc = True sd3_taesd_enc = True
elif v.startswith("taef1_encoder."):
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc: if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd") vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc: if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl") vaes.append("taesdxl")
if sd3_taesd_dec and sd3_taesd_enc: if sd3_taesd_dec and sd3_taesd_enc:
vaes.append("taesd3") vaes.append("taesd3")
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
return vaes return vaes
@staticmethod @staticmethod
@ -712,6 +720,9 @@ class VAELoader:
elif name == "taesd3": elif name == "taesd3":
sd["vae_scale"] = torch.tensor(1.5305) sd["vae_scale"] = torch.tensor(1.5305)
sd["vae_shift"] = torch.tensor(0.0609) sd["vae_shift"] = torch.tensor(0.0609)
elif name == "taef1":
sd["vae_scale"] = torch.tensor(0.3611)
sd["vae_shift"] = torch.tensor(0.1159)
return sd return sd
@classmethod @classmethod
@ -724,7 +735,7 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name): def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl", "taesd3"]: if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name) sd = self.load_taesd(vae_name)
else: else:
vae_path = folder_paths.get_full_path("vae", vae_name) vae_path = folder_paths.get_full_path("vae", vae_name)