Move cascade scale factor from stage_a to latent_formats.py
This commit is contained in:
parent
f2fe635c9f
commit
d7897fff2c
|
@ -95,7 +95,7 @@ class SC_Prior(LatentFormat):
|
||||||
|
|
||||||
class SC_B(LatentFormat):
|
class SC_B(LatentFormat):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 1.0
|
self.scale_factor = 1.0 / 0.43
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
[ 0.1121, 0.2006, 0.1023],
|
[ 0.1121, 0.2006, 0.1023],
|
||||||
[-0.2093, -0.0222, -0.0195],
|
[-0.2093, -0.0222, -0.0195],
|
||||||
|
|
|
@ -163,11 +163,9 @@ class ResBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class StageA(nn.Module):
|
class StageA(nn.Module):
|
||||||
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
|
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
|
||||||
scale_factor=0.43): # 0.3764
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.c_latent = c_latent
|
self.c_latent = c_latent
|
||||||
self.scale_factor = scale_factor
|
|
||||||
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
|
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
|
||||||
|
|
||||||
# Encoder blocks
|
# Encoder blocks
|
||||||
|
@ -214,12 +212,11 @@ class StageA(nn.Module):
|
||||||
x = self.down_blocks(x)
|
x = self.down_blocks(x)
|
||||||
if quantize:
|
if quantize:
|
||||||
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
|
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
|
||||||
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
|
return qe, x, indices, vq_loss + commit_loss * 0.25
|
||||||
else:
|
else:
|
||||||
return x / self.scale_factor
|
return x
|
||||||
|
|
||||||
def decode(self, x):
|
def decode(self, x):
|
||||||
x = x * self.scale_factor
|
|
||||||
x = self.up_blocks(x)
|
x = self.up_blocks(x)
|
||||||
x = self.out_block(x)
|
x = self.out_block(x)
|
||||||
return x
|
return x
|
||||||
|
|
Loading…
Reference in New Issue