Support lcm models.

Use the "lcm" sampler to sample them, you also have to use the
ModelSamplingDiscrete node to set them as lcm models to use them properly.
This commit is contained in:
comfyanonymous 2023-11-09 17:57:51 -05:00
parent ca71e542d2
commit 002aefa382
3 changed files with 88 additions and 4 deletions

View File

@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
return mu return mu
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = denoised
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
return x

View File

@ -519,7 +519,7 @@ class UNIPCBH2(Sampler):
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
def ksampler(sampler_name, extra_options={}, inpaint_options={}): def ksampler(sampler_name, extra_options={}, inpaint_options={}):
class KSAMPLER(Sampler): class KSAMPLER(Sampler):

View File

@ -1,6 +1,72 @@
import folder_paths import folder_paths
import comfy.sd import comfy.sd
import comfy.model_sampling import comfy.model_sampling
import torch
class LCM(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
x0 = model_input - model_output * sigma
sigma_data = 0.5
scaled_timestep = timestep * 10.0 #timestep_scaling
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteLCM(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
original_timesteps = 50
self.skip_steps = timesteps // original_timesteps
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
for x in range(original_timesteps):
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
def sigma(self, timestep):
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def percent_to_sigma(self, percent):
return self.sigma(torch.tensor(percent * 999.0))
def rescale_zero_terminal_snr_sigmas(sigmas): def rescale_zero_terminal_snr_sigmas(sigmas):
@ -26,7 +92,7 @@ class ModelSamplingDiscrete:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction"],), "sampling": (["eps", "v_prediction", "lcm"],),
"zsnr": ("BOOLEAN", {"default": False}), "zsnr": ("BOOLEAN", {"default": False}),
}} }}
@ -38,17 +104,22 @@ class ModelSamplingDiscrete:
def patch(self, model, sampling, zsnr): def patch(self, model, sampling, zsnr):
m = model.clone() m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingDiscrete
if sampling == "eps": if sampling == "eps":
sampling_type = comfy.model_sampling.EPS sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction": elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteLCM
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced() model_sampling = ModelSamplingAdvanced()
if zsnr: if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return (m, )