From 002aefa382585d171aef13c7bd21f64b8664fe28 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 9 Nov 2023 17:57:51 -0500 Subject: [PATCH] Support lcm models. Use the "lcm" sampler to sample them, you also have to use the ModelSamplingDiscrete node to set them as lcm models to use them properly. --- comfy/k_diffusion/sampling.py | 15 +++++- comfy/samplers.py | 2 +- comfy_extras/nodes_model_advanced.py | 75 +++++++++++++++++++++++++++- 3 files changed, 88 insertions(+), 4 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 937c5a38..dd6f7bbe 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) return mu - def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) +@torch.no_grad() +def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index 964febb2..d7ff8985 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -519,7 +519,7 @@ class UNIPCBH2(Sampler): KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] def ksampler(sampler_name, extra_options={}, inpaint_options={}): class KSAMPLER(Sampler): diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index c02cfb05..42596fbd 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -1,6 +1,72 @@ import folder_paths import comfy.sd import comfy.model_sampling +import torch + +class LCM(comfy.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + x0 = model_input - model_output * sigma + + sigma_data = 0.5 + scaled_timestep = timestep * 10.0 #timestep_scaling + + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + + return c_out * x0 + c_skip * model_input + +class ModelSamplingDiscreteLCM(torch.nn.Module): + def __init__(self): + super().__init__() + self.sigma_data = 1.0 + timesteps = 1000 + beta_start = 0.00085 + beta_end = 0.012 + + betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + original_timesteps = 50 + self.skip_steps = timesteps // original_timesteps + + + alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32) + for x in range(original_timesteps): + alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + + sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 + self.set_sigmas(sigmas) + + def set_sigmas(self, sigmas): + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) + + def sigma(self, timestep): + t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + + def percent_to_sigma(self, percent): + return self.sigma(torch.tensor(percent * 999.0)) def rescale_zero_terminal_snr_sigmas(sigmas): @@ -26,7 +92,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction"],), + "sampling": (["eps", "v_prediction", "lcm"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -38,17 +104,22 @@ class ModelSamplingDiscrete: def patch(self, model, sampling, zsnr): m = model.clone() + sampling_base = comfy.model_sampling.ModelSamplingDiscrete if sampling == "eps": sampling_type = comfy.model_sampling.EPS elif sampling == "v_prediction": sampling_type = comfy.model_sampling.V_PREDICTION + elif sampling == "lcm": + sampling_type = LCM + sampling_base = ModelSamplingDiscreteLCM - class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type): + class ModelSamplingAdvanced(sampling_base, sampling_type): pass model_sampling = ModelSamplingAdvanced() if zsnr: model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) + m.add_object_patch("model_sampling", model_sampling) return (m, )