Add sampler_pre_cfg_function (#3979)
* Update samplers.py * Update model_patcher.py
This commit is contained in:
parent
c3db344746
commit
f1a01c2c7e
|
@ -57,6 +57,12 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl
|
||||||
model_options["disable_cfg1_optimization"] = True
|
model_options["disable_cfg1_optimization"] = True
|
||||||
return model_options
|
return model_options
|
||||||
|
|
||||||
|
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
|
||||||
|
model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
|
||||||
|
if disable_cfg1_optimization:
|
||||||
|
model_options["disable_cfg1_optimization"] = True
|
||||||
|
return model_options
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
|
@ -130,6 +136,9 @@ class ModelPatcher:
|
||||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||||
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
||||||
|
|
||||||
|
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
|
||||||
|
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
|
||||||
|
|
||||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
|
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
|
||||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||||
|
|
||||||
|
|
|
@ -275,6 +275,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||||
|
|
||||||
conds = [cond, uncond_]
|
conds = [cond, uncond_]
|
||||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||||
|
|
||||||
|
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||||
|
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
||||||
|
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
||||||
|
out = fn(args)
|
||||||
|
|
||||||
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
|
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue