From 1770fc77ed91348f4060e3c0b040c1519d6f91d0 Mon Sep 17 00:00:00 2001 From: Matthew Turnshek Date: Fri, 16 Aug 2024 12:53:13 -0400 Subject: [PATCH] Implement support for taef1 latent previews (#4409) * add taef1 handling to several places * remove guess_latent_channels and add latent_channels info directly to flux model * remove TODO * fix numbers --- comfy/latent_formats.py | 2 ++ nodes.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index ecb03b01..ee19faea 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -141,6 +141,7 @@ class StableAudio1(LatentFormat): latent_channels = 64 class Flux(SD3): + latent_channels = 16 def __init__(self): self.scale_factor = 0.3611 self.shift_factor = 0.1159 @@ -162,6 +163,7 @@ class Flux(SD3): [-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680] ] + self.taesd_decoder_name = "taef1_decoder" def process_in(self, latent): return (latent - self.shift_factor) * self.scale_factor diff --git a/nodes.py b/nodes.py index 16f5c9b0..b817a865 100644 --- a/nodes.py +++ b/nodes.py @@ -665,6 +665,8 @@ class VAELoader: sd1_taesd_dec = False sd3_taesd_enc = False sd3_taesd_dec = False + f1_taesd_enc = False + f1_taesd_dec = False for v in approx_vaes: if v.startswith("taesd_decoder."): @@ -679,12 +681,18 @@ class VAELoader: sd3_taesd_dec = True elif v.startswith("taesd3_encoder."): sd3_taesd_enc = True + elif v.startswith("taef1_encoder."): + f1_taesd_dec = True + elif v.startswith("taef1_decoder."): + f1_taesd_enc = True if sd1_taesd_dec and sd1_taesd_enc: vaes.append("taesd") if sdxl_taesd_dec and sdxl_taesd_enc: vaes.append("taesdxl") if sd3_taesd_dec and sd3_taesd_enc: vaes.append("taesd3") + if f1_taesd_dec and f1_taesd_enc: + vaes.append("taef1") return vaes @staticmethod @@ -712,6 +720,9 @@ class VAELoader: elif name == "taesd3": sd["vae_scale"] = torch.tensor(1.5305) sd["vae_shift"] = torch.tensor(0.0609) + elif name == "taef1": + sd["vae_scale"] = torch.tensor(0.3611) + sd["vae_shift"] = torch.tensor(0.1159) return sd @classmethod @@ -724,7 +735,7 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): - if vae_name in ["taesd", "taesdxl", "taesd3"]: + if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: sd = self.load_taesd(vae_name) else: vae_path = folder_paths.get_full_path("vae", vae_name)