Move cascade scale factor from stage_a to latent_formats.py

This commit is contained in:
comfyanonymous 2024-03-16 14:49:35 -04:00
parent f2fe635c9f
commit d7897fff2c
2 changed files with 4 additions and 7 deletions

View File

@ -95,7 +95,7 @@ class SC_Prior(LatentFormat):
class SC_B(LatentFormat):
def __init__(self):
self.scale_factor = 1.0
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
[ 0.1121, 0.2006, 0.1023],
[-0.2093, -0.0222, -0.0195],

View File

@ -163,11 +163,9 @@ class ResBlock(nn.Module):
class StageA(nn.Module):
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
scale_factor=0.43): # 0.3764
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
super().__init__()
self.c_latent = c_latent
self.scale_factor = scale_factor
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
# Encoder blocks
@ -214,12 +212,11 @@ class StageA(nn.Module):
x = self.down_blocks(x)
if quantize:
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
return qe, x, indices, vq_loss + commit_loss * 0.25
else:
return x / self.scale_factor
return x
def decode(self, x):
x = x * self.scale_factor
x = self.up_blocks(x)
x = self.out_block(x)
return x