Implement support for taef1 latent previews (#4409)
* add taef1 handling to several places * remove guess_latent_channels and add latent_channels info directly to flux model * remove TODO * fix numbers
This commit is contained in:
parent
05a9f3faa1
commit
1770fc77ed
|
@ -141,6 +141,7 @@ class StableAudio1(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
|
|
||||||
class Flux(SD3):
|
class Flux(SD3):
|
||||||
|
latent_channels = 16
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 0.3611
|
self.scale_factor = 0.3611
|
||||||
self.shift_factor = 0.1159
|
self.shift_factor = 0.1159
|
||||||
|
@ -162,6 +163,7 @@ class Flux(SD3):
|
||||||
[-0.0005, -0.0530, -0.0020],
|
[-0.0005, -0.0530, -0.0020],
|
||||||
[-0.1273, -0.0932, -0.0680]
|
[-0.1273, -0.0932, -0.0680]
|
||||||
]
|
]
|
||||||
|
self.taesd_decoder_name = "taef1_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
return (latent - self.shift_factor) * self.scale_factor
|
return (latent - self.shift_factor) * self.scale_factor
|
||||||
|
|
13
nodes.py
13
nodes.py
|
@ -665,6 +665,8 @@ class VAELoader:
|
||||||
sd1_taesd_dec = False
|
sd1_taesd_dec = False
|
||||||
sd3_taesd_enc = False
|
sd3_taesd_enc = False
|
||||||
sd3_taesd_dec = False
|
sd3_taesd_dec = False
|
||||||
|
f1_taesd_enc = False
|
||||||
|
f1_taesd_dec = False
|
||||||
|
|
||||||
for v in approx_vaes:
|
for v in approx_vaes:
|
||||||
if v.startswith("taesd_decoder."):
|
if v.startswith("taesd_decoder."):
|
||||||
|
@ -679,12 +681,18 @@ class VAELoader:
|
||||||
sd3_taesd_dec = True
|
sd3_taesd_dec = True
|
||||||
elif v.startswith("taesd3_encoder."):
|
elif v.startswith("taesd3_encoder."):
|
||||||
sd3_taesd_enc = True
|
sd3_taesd_enc = True
|
||||||
|
elif v.startswith("taef1_encoder."):
|
||||||
|
f1_taesd_dec = True
|
||||||
|
elif v.startswith("taef1_decoder."):
|
||||||
|
f1_taesd_enc = True
|
||||||
if sd1_taesd_dec and sd1_taesd_enc:
|
if sd1_taesd_dec and sd1_taesd_enc:
|
||||||
vaes.append("taesd")
|
vaes.append("taesd")
|
||||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||||
vaes.append("taesdxl")
|
vaes.append("taesdxl")
|
||||||
if sd3_taesd_dec and sd3_taesd_enc:
|
if sd3_taesd_dec and sd3_taesd_enc:
|
||||||
vaes.append("taesd3")
|
vaes.append("taesd3")
|
||||||
|
if f1_taesd_dec and f1_taesd_enc:
|
||||||
|
vaes.append("taef1")
|
||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -712,6 +720,9 @@ class VAELoader:
|
||||||
elif name == "taesd3":
|
elif name == "taesd3":
|
||||||
sd["vae_scale"] = torch.tensor(1.5305)
|
sd["vae_scale"] = torch.tensor(1.5305)
|
||||||
sd["vae_shift"] = torch.tensor(0.0609)
|
sd["vae_shift"] = torch.tensor(0.0609)
|
||||||
|
elif name == "taef1":
|
||||||
|
sd["vae_scale"] = torch.tensor(0.3611)
|
||||||
|
sd["vae_shift"] = torch.tensor(0.1159)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -724,7 +735,7 @@ class VAELoader:
|
||||||
|
|
||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
if vae_name in ["taesd", "taesdxl", "taesd3"]:
|
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||||
|
|
Loading…
Reference in New Issue