Implement beta sampling scheduler.
It is based on: https://arxiv.org/abs/2407.12173 Add "beta" to the list of schedulers and the BetaSamplingScheduler node.
This commit is contained in:
parent
011b11d8d7
commit
6ab8cad22e
|
@ -6,6 +6,8 @@ from comfy import model_management
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
|
import scipy
|
||||||
|
import numpy
|
||||||
|
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
dims = tuple(x_in.shape[2:])
|
dims = tuple(x_in.shape[2:])
|
||||||
|
@ -337,6 +339,18 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
|
# Implemented based on: https://arxiv.org/abs/2407.12173
|
||||||
|
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
||||||
|
total_timesteps = (len(model_sampling.sigmas) - 1)
|
||||||
|
ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
|
||||||
|
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
||||||
|
|
||||||
|
sigs = []
|
||||||
|
for t in ts:
|
||||||
|
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||||
|
sigs += [0.0]
|
||||||
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def get_mask_aabb(masks):
|
def get_mask_aabb(masks):
|
||||||
if masks.numel() == 0:
|
if masks.numel() == 0:
|
||||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||||
|
@ -703,7 +717,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
|
|
||||||
|
|
||||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
|
||||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||||
|
@ -719,6 +733,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||||
sigmas = ddim_scheduler(model_sampling, steps)
|
sigmas = ddim_scheduler(model_sampling, steps)
|
||||||
elif scheduler_name == "sgm_uniform":
|
elif scheduler_name == "sgm_uniform":
|
||||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||||
|
elif scheduler_name == "beta":
|
||||||
|
sigmas = beta_scheduler(model_sampling, steps)
|
||||||
else:
|
else:
|
||||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||||
return sigmas
|
return sigmas
|
||||||
|
|
|
@ -111,6 +111,25 @@ class SDTurboScheduler:
|
||||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
|
class BetaSamplingScheduler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"model": ("MODEL",),
|
||||||
|
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
|
"alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
|
||||||
|
"beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
|
def get_sigmas(self, model, steps, alpha, beta):
|
||||||
|
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
|
||||||
|
return (sigmas, )
|
||||||
|
|
||||||
class VPScheduler:
|
class VPScheduler:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
@ -638,6 +657,7 @@ NODE_CLASS_MAPPINGS = {
|
||||||
"ExponentialScheduler": ExponentialScheduler,
|
"ExponentialScheduler": ExponentialScheduler,
|
||||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||||
"VPScheduler": VPScheduler,
|
"VPScheduler": VPScheduler,
|
||||||
|
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||||
"SDTurboScheduler": SDTurboScheduler,
|
"SDTurboScheduler": SDTurboScheduler,
|
||||||
"KSamplerSelect": KSamplerSelect,
|
"KSamplerSelect": KSamplerSelect,
|
||||||
"SamplerEulerAncestral": SamplerEulerAncestral,
|
"SamplerEulerAncestral": SamplerEulerAncestral,
|
||||||
|
|
Loading…
Reference in New Issue