From 6ab8cad22ee376c0a9a506924cca278bb650afcc Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 19 Jul 2024 17:44:56 -0400
Subject: [PATCH] Implement beta sampling scheduler.

It is based on: https://arxiv.org/abs/2407.12173

Add "beta" to the list of schedulers and the BetaSamplingScheduler node.
---
 comfy/samplers.py                    | 18 +++++++++++++++++-
 comfy_extras/nodes_custom_sampler.py | 20 ++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/comfy/samplers.py b/comfy/samplers.py
index bbf9219f..0ab08a4c 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -6,6 +6,8 @@ from comfy import model_management
 import math
 import logging
 import comfy.sampler_helpers
+import scipy
+import numpy
 
 def get_area_and_mult(conds, x_in, timestep_in):
     dims = tuple(x_in.shape[2:])
@@ -337,6 +339,18 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
+# Implemented based on: https://arxiv.org/abs/2407.12173
+def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
+    total_timesteps = (len(model_sampling.sigmas) - 1)
+    ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
+    ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
+
+    sigs = []
+    for t in ts:
+        sigs += [float(model_sampling.sigmas[int(t)])]
+    sigs += [0.0]
+    return torch.FloatTensor(sigs)
+
 def get_mask_aabb(masks):
     if masks.numel() == 0:
         return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@@ -703,7 +717,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
 
     return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
 
-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
+SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
 
 def calculate_sigmas(model_sampling, scheduler_name, steps):
@@ -719,6 +733,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
         sigmas = ddim_scheduler(model_sampling, steps)
     elif scheduler_name == "sgm_uniform":
         sigmas = normal_scheduler(model_sampling, steps, sgm=True)
+    elif scheduler_name == "beta":
+        sigmas = beta_scheduler(model_sampling, steps)
     else:
         logging.error("error invalid scheduler {}".format(scheduler_name))
     return sigmas
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index 64a8c063..b7ab88c2 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -111,6 +111,25 @@ class SDTurboScheduler:
         sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
         return (sigmas, )
 
+class BetaSamplingScheduler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"model": ("MODEL",),
+                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
+                     "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
+                     "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
+                    }
+               }
+    RETURN_TYPES = ("SIGMAS",)
+    CATEGORY = "sampling/custom_sampling/schedulers"
+
+    FUNCTION = "get_sigmas"
+
+    def get_sigmas(self, model, steps, alpha, beta):
+        sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
+        return (sigmas, )
+
 class VPScheduler:
     @classmethod
     def INPUT_TYPES(s):
@@ -638,6 +657,7 @@ NODE_CLASS_MAPPINGS = {
     "ExponentialScheduler": ExponentialScheduler,
     "PolyexponentialScheduler": PolyexponentialScheduler,
     "VPScheduler": VPScheduler,
+    "BetaSamplingScheduler": BetaSamplingScheduler,
     "SDTurboScheduler": SDTurboScheduler,
     "KSamplerSelect": KSamplerSelect,
     "SamplerEulerAncestral": SamplerEulerAncestral,