Cleanup SkipLayerGuidanceSD3 node.
This commit is contained in:
parent
954683d0db
commit
770ab200f2
|
@ -104,7 +104,7 @@ class SkipLayerGuidanceSD3:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return {"required": {"model": ("MODEL", ),
|
||||||
"layers": ("STRING", {"default": "7,8,9", "multiline": False}),
|
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
@ -119,11 +119,12 @@ class SkipLayerGuidanceSD3:
|
||||||
if layers == "" or layers == None:
|
if layers == "" or layers == None:
|
||||||
return (model, )
|
return (model, )
|
||||||
# check if layer is comma separated integers
|
# check if layer is comma separated integers
|
||||||
assert layers.replace(",", "").isdigit(), "Layers must be comma separated integers"
|
|
||||||
def skip(args, extra_args):
|
def skip(args, extra_args):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
model_sampling = model.get_model_object("model_sampling")
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
|
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||||
|
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||||
|
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
model = args["model"]
|
model = args["model"]
|
||||||
|
@ -137,10 +138,9 @@ class SkipLayerGuidanceSD3:
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
||||||
model_sampling.percent_to_sigma(start_percent)
|
model_sampling.percent_to_sigma(start_percent)
|
||||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
|
||||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
|
||||||
sigma_ = sigma[0].item()
|
sigma_ = sigma[0].item()
|
||||||
if scale > 0 and sigma_ > sigma_end and sigma_ < sigma_start:
|
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||||
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||||
cfg_result = cfg_result + (cond_pred - slg) * scale
|
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||||
return cfg_result
|
return cfg_result
|
||||||
|
|
Loading…
Reference in New Issue