298 lines
11 KiB
Python
298 lines
11 KiB
Python
import comfy.samplers
|
|
import comfy.sample
|
|
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
|
import latent_preview
|
|
import torch
|
|
import comfy.utils
|
|
|
|
|
|
class BasicScheduler:
    """Build a SIGMAS schedule from one of the named schedulers in ``comfy.samplers``."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                     }
                }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, model, scheduler, steps, denoise):
        """Return ``(sigmas,)`` with at most ``steps + 1`` sigma values.

        When ``denoise < 1`` the schedule is computed for
        ``int(steps / denoise)`` steps and only its last ``steps + 1``
        values are kept.

        ``denoise <= 0`` (the widget allows 0.0) yields an empty tensor
        instead of dividing by zero.
        """
        total_steps = steps
        if denoise < 1.0:
            if denoise <= 0.0:
                # No denoising requested: empty schedule rather than ZeroDivisionError.
                return (torch.FloatTensor([]), )
            total_steps = int(steps / denoise)

        inner_model = model.patch_model(patch_weights=False)
        try:
            sigmas = comfy.samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
        finally:
            # Restore the model even if the scheduler computation raises.
            model.unpatch_model()

        # Keep only the tail of the (possibly longer) schedule.
        sigmas = sigmas[-(steps + 1):]
        return (sigmas, )
|
|
|
|
|
|
class KarrasScheduler:
    """Produce a SIGMAS schedule with k-diffusion's ``get_sigmas_karras``."""

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"
    FUNCTION = "get_sigmas"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step": 0.01, "round": False}),
            "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step": 0.01, "round": False}),
            "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
        }
        return {"required": required}

    def get_sigmas(self, steps, sigma_max, sigma_min, rho):
        """Return a one-element tuple holding the k-diffusion Karras schedule."""
        schedule = k_diffusion_sampling.get_sigmas_karras(
            n=steps,
            sigma_max=sigma_max,
            sigma_min=sigma_min,
            rho=rho,
        )
        return (schedule, )
|
|
|
|
class ExponentialScheduler:
    """Produce a SIGMAS schedule with k-diffusion's ``get_sigmas_exponential``."""

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"
    FUNCTION = "get_sigmas"

    @classmethod
    def INPUT_TYPES(cls):
        def sigma_input(default):
            # Fresh option dict per call so widgets never share mutable state.
            return ("FLOAT", {"default": default, "min": 0.0, "max": 1000.0, "step": 0.01, "round": False})

        return {"required": {
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "sigma_max": sigma_input(14.614642),
            "sigma_min": sigma_input(0.0291675),
        }}

    def get_sigmas(self, steps, sigma_max, sigma_min):
        """Return a one-element tuple holding the k-diffusion exponential schedule."""
        result = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_max=sigma_max, sigma_min=sigma_min)
        return (result, )
|
|
|
|
class PolyexponentialScheduler:
    """Produce a SIGMAS schedule with k-diffusion's ``get_sigmas_polyexponential``."""

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"
    FUNCTION = "get_sigmas"

    @classmethod
    def INPUT_TYPES(cls):
        def float_input(default, maximum=1000.0):
            return ("FLOAT", {"default": default, "min": 0.0, "max": maximum, "step": 0.01, "round": False})

        return {"required": {
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "sigma_max": float_input(14.614642),
            "sigma_min": float_input(0.0291675),
            "rho": float_input(1.0, maximum=100.0),
        }}

    def get_sigmas(self, steps, sigma_max, sigma_min, rho):
        """Return a one-element tuple holding the k-diffusion polyexponential schedule."""
        result = k_diffusion_sampling.get_sigmas_polyexponential(
            n=steps, sigma_max=sigma_max, sigma_min=sigma_min, rho=rho
        )
        return (result, )
|
|
|
|
class SDTurboScheduler:
    """Build a short SIGMAS schedule from a fixed grid of ten timesteps.

    The grid is 999, 899, ..., 99 (descending). Timesteps are converted to
    sigmas with the patched model's ``model_sampling.sigma``, and a final 0
    is appended.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "steps": ("INT", {"default": 1, "min": 1, "max": 10}),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
                     }
                }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, model, steps, denoise):
        """Return ``(sigmas,)``: up to ``steps`` grid sigmas followed by 0.

        ``denoise`` skips the first ``10 - int(10 * denoise)`` grid entries, so
        ``denoise=1.0`` starts at timestep 999 and ``denoise=0.5`` at 499.
        """
        start_step = 10 - int(10 * denoise)
        timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]

        inner_model = model.patch_model(patch_weights=False)
        try:
            sigmas = inner_model.model_sampling.sigma(timesteps)
        finally:
            # Restore the model even if the timestep->sigma conversion raises.
            model.unpatch_model()

        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        return (sigmas, )
|
|
|
|
class VPScheduler:
    """Produce a SIGMAS schedule with k-diffusion's ``get_sigmas_vp``."""

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"
    FUNCTION = "get_sigmas"

    @classmethod
    def INPUT_TYPES(cls):
        def float_input(default, maximum, step=0.01):
            return ("FLOAT", {"default": default, "min": 0.0, "max": maximum, "step": step, "round": False})

        return {"required": {
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "beta_d": float_input(19.9, 1000.0),  # TODO: fix default values
            "beta_min": float_input(0.1, 1000.0),
            "eps_s": float_input(0.001, 1.0, step=0.0001),
        }}

    def get_sigmas(self, steps, beta_d, beta_min, eps_s):
        """Return a one-element tuple holding the k-diffusion VP schedule."""
        result = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
        return (result, )
|
|
|
|
class SplitSigmas:
    """Cut one SIGMAS schedule into two overlapping pieces at index ``step``."""

    RETURN_TYPES = ("SIGMAS", "SIGMAS")
    CATEGORY = "sampling/custom_sampling/sigmas"
    FUNCTION = "get_sigmas"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "sigmas": ("SIGMAS", ),
            "step": ("INT", {"default": 0, "min": 0, "max": 10000}),
        }
        return {"required": required}

    def get_sigmas(self, sigmas, step):
        """Return ``(head, tail)``.

        ``head`` covers indices ``0..step`` inclusive and ``tail`` starts at
        ``step``, so the sigma at ``step`` appears in both pieces.
        """
        head = sigmas[:step + 1]
        tail = sigmas[step:]
        return head, tail
|
|
|
|
class FlipSigmas:
    """Reverse the order of a SIGMAS schedule."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"sigmas": ("SIGMAS", ),
                     }
                }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/sigmas"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, sigmas):
        """Return ``(flipped,)`` where ``flipped`` is ``sigmas`` reversed along dim 0.

        If the first value after flipping is exactly 0 it is replaced with
        0.0001. Presumably this keeps the starting sigma nonzero for
        downstream samplers; confirm. An empty schedule, which ``SplitSigmas``
        can produce, is returned unchanged instead of raising IndexError on
        ``sigmas[0]``.
        """
        if len(sigmas) == 0:
            return (sigmas,)

        # flip() returns a new tensor, so the in-place write below does not
        # modify the caller's input.
        sigmas = sigmas.flip(0)
        if sigmas[0] == 0:
            sigmas[0] = 0.0001
        return (sigmas,)
|
|
|
|
class KSamplerSelect:
    """Expose a sampler chosen by name from ``comfy.samplers.SAMPLER_NAMES`` as a SAMPLER."""

    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"
    FUNCTION = "get_sampler"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"sampler_name": (comfy.samplers.SAMPLER_NAMES, )}}

    def get_sampler(self, sampler_name):
        """Return a one-element tuple with ``comfy.samplers.sampler_object(sampler_name)``."""
        return (comfy.samplers.sampler_object(sampler_name), )
|
|
|
|
class SamplerDPMPP_2M_SDE:
    """Configure the ``dpmpp_2m_sde`` k-sampler, with noise drawn on GPU or CPU."""

    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"
    FUNCTION = "get_sampler"

    @classmethod
    def INPUT_TYPES(cls):
        def float_input():
            return ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False})

        return {"required": {
            "solver_type": (['midpoint', 'heun'], ),
            "eta": float_input(),
            "s_noise": float_input(),
            "noise_device": (['gpu', 'cpu'], ),
        }}

    def get_sampler(self, solver_type, eta, s_noise, noise_device):
        """Return a one-element tuple with the configured sampler.

        ``noise_device == 'cpu'`` selects ``dpmpp_2m_sde``; any other value
        selects ``dpmpp_2m_sde_gpu``.
        """
        sampler_name = "dpmpp_2m_sde" if noise_device == 'cpu' else "dpmpp_2m_sde_gpu"
        options = {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}
        return (comfy.samplers.ksampler(sampler_name, options), )
|
|
|
|
|
|
class SamplerDPMPP_SDE:
    """Configure the ``dpmpp_sde`` k-sampler, with noise drawn on GPU or CPU."""

    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"
    FUNCTION = "get_sampler"

    @classmethod
    def INPUT_TYPES(cls):
        def float_input(default):
            return ("FLOAT", {"default": default, "min": 0.0, "max": 100.0, "step": 0.01, "round": False})

        return {"required": {
            "eta": float_input(1.0),
            "s_noise": float_input(1.0),
            "r": float_input(0.5),
            "noise_device": (['gpu', 'cpu'], ),
        }}

    def get_sampler(self, eta, s_noise, r, noise_device):
        """Return a one-element tuple with the configured sampler.

        ``noise_device == 'cpu'`` selects ``dpmpp_sde``; any other value
        selects ``dpmpp_sde_gpu``.
        """
        sampler_name = "dpmpp_sde" if noise_device == 'cpu' else "dpmpp_sde_gpu"
        options = {"eta": eta, "s_noise": s_noise, "r": r}
        return (comfy.samplers.ksampler(sampler_name, options), )
|
|
|
|
class SamplerCustom:
    """Sample a latent with an explicit SAMPLER and SIGMAS schedule.

    Produces two LATENT outputs: the sampled result and a "denoised" latent.
    """

    RETURN_TYPES = ("LATENT", "LATENT")
    RETURN_NAMES = ("output", "denoised_output")
    FUNCTION = "sample"
    CATEGORY = "sampling/custom_sampling"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "add_noise": ("BOOLEAN", {"default": True}),
            "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "sampler": ("SAMPLER", ),
            "sigmas": ("SIGMAS", ),
            "latent_image": ("LATENT", ),
        }
        return {"required": required}

    def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
        """Run ``comfy.sample.sample_custom`` and return ``(output, denoised_output)``.

        ``latent_image`` is a dict with a ``"samples"`` tensor and optional
        ``"batch_index"`` / ``"noise_mask"`` entries. ``denoised_output`` holds
        the ``"x0"`` value captured in ``x0_output`` (presumably written by the
        preview callback; confirm), passed through
        ``model.model.process_latent_out``. If nothing was captured it is the
        same dict object as ``output``.
        """
        latent = latent_image
        samples_in = latent["samples"]

        if add_noise:
            noise = comfy.sample.prepare_noise(samples_in, noise_seed, latent.get("batch_index"))
        else:
            # No added noise: an all-zeros tensor matching the latent, on CPU.
            noise = torch.zeros(samples_in.size(), dtype=samples_in.dtype, layout=samples_in.layout, device="cpu")

        noise_mask = latent.get("noise_mask")

        x0_output = {}
        callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)

        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = comfy.sample.sample_custom(
            model, noise, cfg, sampler, sigmas, positive, negative, samples_in,
            noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed,
        )

        out = latent.copy()
        out["samples"] = samples

        if "x0" not in x0_output:
            return (out, out)

        out_denoised = latent.copy()
        out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
        return (out, out_denoised)
|
|
|
|
# Node type name -> implementing class. Every key is the class's own __name__,
# and insertion order matches the tuple below. Presumably read by ComfyUI's
# node loader; confirm.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        SamplerCustom,
        BasicScheduler,
        KarrasScheduler,
        ExponentialScheduler,
        PolyexponentialScheduler,
        VPScheduler,
        SDTurboScheduler,
        KSamplerSelect,
        SamplerDPMPP_2M_SDE,
        SamplerDPMPP_SDE,
        SplitSigmas,
        FlipSigmas,
    )
}
|