Switch to the real cfg++ method in the samplers.

The old _pp ones will be updated automatically to the regular ones with 2x
the cfg.

My fault for not checking what the "_pp" samplers actually did.
This commit is contained in:
comfyanonymous 2024-06-29 11:59:48 -04:00
parent fbb7a1f1b6
commit 05e831697a
3 changed files with 15 additions and 7 deletions

View File

@ -998,7 +998,7 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
@torch.no_grad()
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
temp = [0]
@ -1013,16 +1013,16 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
for i in trange(len(sigmas) - 1, disable=disable):
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x - denoised + temp[0], sigma_hat, denoised)
d = to_d(x, sigma_hat, temp[0])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
x = denoised + d * sigmas[i + 1]
return x
@torch.no_grad()
def sample_euler_ancestral_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -1041,10 +1041,10 @@ def sample_euler_ancestral_pp(model, x, sigmas, extra_args=None, callback=None,
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x - denoised + temp[0], sigmas[i], denoised)
d = to_d(x, sigmas[i], temp[0])
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = denoised + d * sigma_down
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x

View File

@ -537,7 +537,7 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_pp", "euler_ancestral", "euler_ancestral_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]

View File

@ -1966,6 +1966,14 @@ export class ComfyApp {
if (widget.value.startsWith("sample_")) {
widget.value = widget.value.slice(7);
}
if (widget.value === "euler_pp" || widget.value === "euler_ancestral_pp") {
widget.value = widget.value.slice(0, -3);
for (let w of node.widgets) {
if (w.name == "cfg") {
w.value *= 2.0;
}
}
}
}
}
if (node.type == "KSampler" || node.type == "KSamplerAdvanced" || node.type == "PrimitiveNode") {