Move cascade scale factor from stage_a to latent_formats.py
This commit is contained in:
parent
f2fe635c9f
commit
d7897fff2c
|
@ -95,7 +95,7 @@ class SC_Prior(LatentFormat):
|
|||
|
||||
class SC_B(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.scale_factor = 1.0 / 0.43
|
||||
self.latent_rgb_factors = [
|
||||
[ 0.1121, 0.2006, 0.1023],
|
||||
[-0.2093, -0.0222, -0.0195],
|
||||
|
|
|
@ -163,11 +163,9 @@ class ResBlock(nn.Module):
|
|||
|
||||
|
||||
class StageA(nn.Module):
|
||||
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
|
||||
scale_factor=0.43): # 0.3764
|
||||
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
|
||||
super().__init__()
|
||||
self.c_latent = c_latent
|
||||
self.scale_factor = scale_factor
|
||||
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
|
||||
|
||||
# Encoder blocks
|
||||
|
@ -214,12 +212,11 @@ class StageA(nn.Module):
|
|||
x = self.down_blocks(x)
|
||||
if quantize:
|
||||
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
|
||||
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
|
||||
return qe, x, indices, vq_loss + commit_loss * 0.25
|
||||
else:
|
||||
return x / self.scale_factor
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
x = x * self.scale_factor
|
||||
x = self.up_blocks(x)
|
||||
x = self.out_block(x)
|
||||
return x
|
||||
|
|
Loading…
Reference in New Issue