diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index e9e4edcc..70273d9d 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1154,3 +1154,36 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x + +@torch.no_grad() +def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): + """DPM-Solver++(2M).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + t_fn = lambda sigma: sigma.log().neg() + + old_uncond_denoised = None + uncond_denoised = None + def post_cfg_function(args): + nonlocal uncond_denoised + uncond_denoised = args["uncond_denoised"] + return args["denoised"] + + model_options = extra_args.get("model_options", {}).copy() + extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + if old_uncond_denoised is None or sigmas[i + 1] == 0: + denoised_mix = -torch.exp(-h) * uncond_denoised + else: + h_last = t - t_fn(sigmas[i - 1]) + r = h_last / h + denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised) + x = denoised + denoised_mix + torch.exp(-h) * x + old_uncond_denoised = uncond_denoised + return x \ No newline at end of file diff --git a/comfy/samplers.py b/comfy/samplers.py index a20f65d4..1ecb41dd 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -571,7 +571,7 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", + "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis"] class KSAMPLER(Sampler):