Add callback to sampler function.
Callback format is: callback(step, x0, x)
This commit is contained in:
parent
3a1f9dba20
commit
5a971cecdb
|
@ -712,7 +712,7 @@ class UniPC:
|
|||
|
||||
def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||
atol=0.0078, rtol=0.05, corrector=False,
|
||||
atol=0.0078, rtol=0.05, corrector=False, callback=None
|
||||
):
|
||||
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
||||
|
@ -766,6 +766,8 @@ class UniPC:
|
|||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
model_prev_list[-1] = model_x
|
||||
if callback is not None:
|
||||
callback(step_index, model_prev_list[-1], x)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if denoise_to_zero:
|
||||
|
@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|||
|
||||
order = min(3, len(timesteps) - 1)
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback)
|
||||
if not to_zero:
|
||||
x /= ns.marginal_alpha(timesteps[-1])
|
||||
return x
|
||||
|
|
|
@ -56,7 +56,7 @@ def cleanup_additional_models(models):
|
|||
for m in models:
|
||||
m.cleanup()
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None):
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
if noise_mask is not None:
|
||||
|
@ -76,7 +76,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
|||
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas)
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
|
||||
samples = samples.cpu()
|
||||
|
||||
cleanup_additional_models(models)
|
||||
|
|
|
@ -462,7 +462,7 @@ class KSampler:
|
|||
self.sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None):
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
|
||||
if sigmas is None:
|
||||
sigmas = self.sigmas
|
||||
sigma_min = self.sigma_min
|
||||
|
@ -527,9 +527,9 @@ class KSampler:
|
|||
|
||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback)
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2')
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
|
@ -537,6 +537,11 @@ class KSampler:
|
|||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None)
|
||||
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
|
@ -550,6 +555,7 @@ class KSampler:
|
|||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=sampling_function,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
|
@ -563,13 +569,17 @@ class KSampler:
|
|||
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
k_callback = None
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"])
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if self.sampler == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback)
|
||||
elif self.sampler == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args)
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
|
||||
|
||||
return samples.to(torch.float32)
|
||||
|
|
Loading…
Reference in New Issue