Add DualCFGGuider used in IP2P models for example.

This commit is contained in:
comfyanonymous 2024-04-04 14:57:44 -04:00
parent cfbf3be54b
commit 5272fd4b03
1 changed files with 36 additions and 0 deletions

View File

@ -428,6 +428,41 @@ class CFGGuider:
guider.set_cfg(cfg) guider.set_cfg(cfg)
return (guider,) return (guider,)
class Guider_DualCFG(comfy.samplers.CFGGuider):
def set_cfg(self, cfg1, cfg2):
self.cfg1 = cfg1
self.cfg2 = cfg2
def set_conds(self, positive, middle, negative):
self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})
def predict_noise(self, x, timestep, model_options={}, seed=None):
out = comfy.samplers.calc_cond_batch(self.inner_model, [self.conds.get("negative", None), self.conds.get("middle", None), self.conds.get("positive", None)], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options) + (out[2] - out[1]) * self.cfg1
class DualCFGGuider:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"cond1": ("CONDITIONING", ),
"cond2": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
}
}
RETURN_TYPES = ("GUIDER",)
FUNCTION = "get_guider"
CATEGORY = "sampling/custom_sampling/guiders"
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
guider = Guider_DualCFG(model)
guider.set_conds(cond1, cond2, negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative)
return (guider,)
class DisableNoise: class DisableNoise:
@classmethod @classmethod
@ -518,6 +553,7 @@ NODE_CLASS_MAPPINGS = {
"FlipSigmas": FlipSigmas, "FlipSigmas": FlipSigmas,
"CFGGuider": CFGGuider, "CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,
"BasicGuider": BasicGuider, "BasicGuider": BasicGuider,
"RandomNoise": RandomNoise, "RandomNoise": RandomNoise,
"DisableNoise": DisableNoise, "DisableNoise": DisableNoise,