Cleanup SkipLayerGuidanceSD3 node.

This commit is contained in:
comfyanonymous 2024-10-29 10:11:46 -04:00
parent 954683d0db
commit 770ab200f2
1 changed files with 5 additions and 5 deletions

View File

@ -104,7 +104,7 @@ class SkipLayerGuidanceSD3:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ), return {"required": {"model": ("MODEL", ),
"layers": ("STRING", {"default": "7,8,9", "multiline": False}), "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
@ -119,11 +119,12 @@ class SkipLayerGuidanceSD3:
if layers == "" or layers == None: if layers == "" or layers == None:
return (model, ) return (model, )
# check if layer is comma separated integers # check if layer is comma separated integers
assert layers.replace(",", "").isdigit(), "Layers must be comma separated integers"
def skip(args, extra_args): def skip(args, extra_args):
return args return args
model_sampling = model.get_model_object("model_sampling") model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
def post_cfg_function(args): def post_cfg_function(args):
model = args["model"] model = args["model"]
@ -137,10 +138,9 @@ class SkipLayerGuidanceSD3:
for layer in layers: for layer in layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer) model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
model_sampling.percent_to_sigma(start_percent) model_sampling.percent_to_sigma(start_percent)
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
sigma_ = sigma[0].item() sigma_ = sigma[0].item()
if scale > 0 and sigma_ > sigma_end and sigma_ < sigma_start: if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result return cfg_result