Support lcm models.
Use the "lcm" sampler to sample them; you also have to use the ModelSamplingDiscrete node to set them as lcm models so they are sampled properly.
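
A minimal sketch of how the two pieces fit together, assuming a model object already loaded elsewhere (calling the node class directly like this is illustrative only; in practice the node is wired into a workflow):

    # ModelSamplingDiscrete.patch returns a one-element tuple, hence the unpack.
    patched_model, = ModelSamplingDiscrete().patch(model, sampling="lcm", zsnr=False)
    # patched_model can then be sampled with sampler_name "lcm". LCM checkpoints
    # are normally run with few steps and low CFG, but that is general LCM
    # usage, not something this commit enforces.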
parent ca71e542d2
commit 002aefa382
@@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
         mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
     return mu
 
-
 def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
     extra_args = {} if extra_args is None else extra_args
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
 def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
     return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
 
+@torch.no_grad()
+def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+        x = denoised
+        if sigmas[i + 1] > 0:
+            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
+    return x
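
To see what the new sampler does numerically, here is a self-contained toy version of the same step rule (toy_denoiser and the sigma values are invented for illustration): each model call is treated as a full denoise to x0, which is then re-noised to the next sigma, so the loop needs no derivative or history the way the DPM samplers do.

    import torch

    def toy_denoiser(x, sigma):
        # hypothetical stand-in for a real LCM checkpoint
        return x / (1.0 + sigma)

    sigmas = torch.tensor([14.6, 3.0, 0.8, 0.0])  # made-up 4-step schedule
    x = torch.randn(1, 4, 8, 8) * sigmas[0]       # start from pure noise
    for i in range(len(sigmas) - 1):
        x = toy_denoiser(x, sigmas[i])            # predicted clean image x0
        if sigmas[i + 1] > 0:
            x = x + sigmas[i + 1] * torch.randn_like(x)  # re-noise for next step
    print(float(x.std()))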
@@ -519,7 +519,7 @@ class UNIPCBH2(Sampler):
 KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
+                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
 
 def ksampler(sampler_name, extra_options={}, inpaint_options={}):
     class KSAMPLER(Sampler):
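
Adding "lcm" to KSAMPLER_NAMES is the only registration needed here because the KSAMPLER wrapper resolves the sampling function by name; roughly (an assumed lookup, the resolution code is not part of this diff):

    sampler_function = getattr(k_diffusion_sampling, "sample_{}".format("lcm"))
    # -> the sample_lcm function added above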
@@ -1,6 +1,72 @@
 import folder_paths
 import comfy.sd
 import comfy.model_sampling
+import torch
+
+class LCM(comfy.model_sampling.EPS):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        x0 = model_input - model_output * sigma
+
+        sigma_data = 0.5
+        scaled_timestep = timestep * 10.0 #timestep_scaling
+
+        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+
+        return c_out * x0 + c_skip * model_input
+
+class ModelSamplingDiscreteLCM(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.sigma_data = 1.0
+        timesteps = 1000
+        beta_start = 0.00085
+        beta_end = 0.012
+
+        betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
+        alphas = 1.0 - betas
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+        original_timesteps = 50
+        self.skip_steps = timesteps // original_timesteps
+
+
+        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
+        for x in range(original_timesteps):
+            alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
+
+        sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
+        self.set_sigmas(sigmas)
+
+    def set_sigmas(self, sigmas):
+        self.register_buffer('sigmas', sigmas)
+        self.register_buffer('log_sigmas', sigmas.log())
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
+
+    def sigma(self, timestep):
+        t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp()
+
+    def percent_to_sigma(self, percent):
+        return self.sigma(torch.tensor(percent * 999.0))
 
 
 def rescale_zero_terminal_snr_sigmas(sigmas):
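
The constants in LCM.calculate_denoised are the consistency-model skip/output coefficients with sigma_data = 0.5 and a timestep scaling of 10. Re-evaluating the two formulas at the endpoints shows the boundary condition they encode:

    sigma_data = 0.5
    for timestep in (0.0, 999.0):
        scaled_timestep = timestep * 10.0
        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
        print(timestep, c_skip, c_out)
    # 0.0   -> c_skip = 1.0, c_out = 0.0: the input passes through unchanged
    # 999.0 -> c_skip ~ 0.0, c_out ~ 1.0: the result is essentially the predicted x0

ModelSamplingDiscreteLCM likewise keeps only every 20th of the 1000 DDPM timesteps (skip_steps = 20), so its schedule has 50 sigmas and timestep() maps each sigma back to one of the original indices 19, 39, ..., 999:

    ms = ModelSamplingDiscreteLCM()
    print(ms.sigmas.shape)                 # torch.Size([50])
    print(int(ms.timestep(ms.sigma_max)))  # 999, the last original timestep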
@@ -26,7 +92,7 @@ class ModelSamplingDiscrete:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["eps", "v_prediction"],),
+                              "sampling": (["eps", "v_prediction", "lcm"],),
                               "zsnr": ("BOOLEAN", {"default": False}),
                               }}
@@ -38,17 +104,22 @@ class ModelSamplingDiscrete:
     def patch(self, model, sampling, zsnr):
         m = model.clone()
 
+        sampling_base = comfy.model_sampling.ModelSamplingDiscrete
         if sampling == "eps":
             sampling_type = comfy.model_sampling.EPS
         elif sampling == "v_prediction":
             sampling_type = comfy.model_sampling.V_PREDICTION
+        elif sampling == "lcm":
+            sampling_type = LCM
+            sampling_base = ModelSamplingDiscreteLCM
 
-        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
             pass
 
         model_sampling = ModelSamplingAdvanced()
         if zsnr:
             model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
 
         m.add_object_patch("model_sampling", model_sampling)
         return (m, )
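
Putting the last hunk together: for sampling == "lcm" the dynamically built class is equivalent to the following sketch, combining the LCM sigma schedule (base class) with the LCM denoising transform (mixin):

    class ModelSamplingAdvanced(ModelSamplingDiscreteLCM, LCM):
        pass

    model_sampling = ModelSamplingAdvanced()
    print(float(model_sampling.sigma_min), float(model_sampling.sigma_max))
    # add_object_patch then swaps this object into the cloned model, so the
    # "lcm" sampler's model calls go through LCM.calculate_denoised instead
    # of the default eps parametrization.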