Clean up and remove modifying zero sigma

This commit is contained in:
chaObserv 2024-10-30 01:16:34 +08:00
parent 70ff03429c
commit c176ad8f50
1 changed files with 10 additions and 14 deletions

View File

@ -1095,10 +1095,6 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
if len(sigmas) <= 1:
return x
if sigmas[-1] == 0:
sigmas = sigmas.clone()
sigmas[-1] = 0.001
extra_args = {} if extra_args is None else extra_args
if tau_func is None:
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@ -1115,7 +1111,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
for i in trange(len(sigmas) - 1, disable=disable):
sigma = sigmas[i]
if i == 0:
# Init the initial values.
# Init the initial values
denoised = model(x, sigma * s_in, **extra_args)
model_prev_list.append(denoised)
sigma_prev_list.append(sigma)
@ -1134,22 +1130,19 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Evaluation step
denoised = model(x_p, sigma * s_in, **extra_args)
# Update model_list
model_prev_list.append(denoised)
# Corrector step
if corrector_order_used > 0:
x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)
noise=noise, sigma=sigma)
else:
x = x_p
del noise, x_p
# Evaluation step if mode = pece and step != steps
# Evaluation step for PECE
if corrector_order_used > 0 and pc_mode == 'PECE':
del model_prev_list[-1]
denoised = model(x, sigma * s_in, **extra_args)
@ -1163,10 +1156,13 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
if callback is not None:
callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
# Extra final step
x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=0, sigma=sigmas[-1])
if sigmas[-1] == 0:
# Denoising step
x = model_prev_list[-1]
else:
x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=0, sigma=sigmas[-1])
return x
@torch.no_grad()