Some optimizations to euler a.
This commit is contained in:
parent
cdc3b97dd5
commit
c1b92b719d
|
@ -175,12 +175,14 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
d = to_d(x, sigmas[i], denoised)
|
|
||||||
# Euler method
|
if sigma_down == 0:
|
||||||
dt = sigma_down - sigmas[i]
|
x = denoised
|
||||||
x = x + d * dt
|
else:
|
||||||
if sigmas[i + 1] > 0:
|
d = to_d(x, sigmas[i], denoised)
|
||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
# Euler method
|
||||||
|
dt = sigma_down - sigmas[i]
|
||||||
|
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -192,19 +194,22 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
|
||||||
sigma_down = sigmas[i+1] * downstep_ratio
|
|
||||||
alpha_ip1 = 1 - sigmas[i+1]
|
|
||||||
alpha_down = 1 - sigma_down
|
|
||||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
|
||||||
# Euler method
|
if sigmas[i + 1] == 0:
|
||||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
x = denoised
|
||||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
else:
|
||||||
if sigmas[i + 1] > 0 and eta > 0:
|
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
|
||||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
sigma_down = sigmas[i + 1] * downstep_ratio
|
||||||
|
alpha_ip1 = 1 - sigmas[i + 1]
|
||||||
|
alpha_down = 1 - sigma_down
|
||||||
|
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
|
||||||
|
# Euler method
|
||||||
|
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||||
|
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||||
|
if eta > 0:
|
||||||
|
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
Loading…
Reference in New Issue