diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 7b54d8c5..ffdd888e 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -164,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): + return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -181,6 +183,29 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x +@torch.no_grad() +def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta + sigma_down = sigmas[i+1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i+1] + alpha_down = 1 - sigma_down + renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + # Euler method + sigma_down_i_ratio = sigma_down / sigmas[i] + x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised + if sigmas[i + 1] > 0 and eta > 0: + x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + return x @torch.no_grad() def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):