ModelSamplingDiscreteLCM -> ModelSamplingDiscreteDistilled
This commit is contained in:
parent
13fdee6abf
commit
488de0b4df
|
@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
|
|||
|
||||
return c_out * x0 + c_skip * model_input
|
||||
|
||||
class ModelSamplingDiscreteLCM(torch.nn.Module):
|
||||
class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
||||
original_timesteps = 50
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sigma_data = 1.0
|
||||
|
@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
|
|||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
original_timesteps = 50
|
||||
self.skip_steps = timesteps // original_timesteps
|
||||
self.skip_steps = timesteps // self.original_timesteps
|
||||
|
||||
|
||||
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
|
||||
for x in range(original_timesteps):
|
||||
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
||||
for x in range(self.original_timesteps):
|
||||
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||
|
||||
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
|
||||
self.set_sigmas(sigmas)
|
||||
|
@ -116,7 +117,7 @@ class ModelSamplingDiscrete:
|
|||
sampling_type = comfy.model_sampling.V_PREDICTION
|
||||
elif sampling == "lcm":
|
||||
sampling_type = LCM
|
||||
sampling_base = ModelSamplingDiscreteLCM
|
||||
sampling_base = ModelSamplingDiscreteDistilled
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue