Add a skip layer guidance node that can also skip single layers.

This one should work for skipping the single layers of models like Flux
and Auraflow.

If you want to see how these models work and how many double/single layers
they have see the "ModelMerge*" nodes for the specific model.
This commit is contained in:
comfyanonymous 2024-11-18 02:20:43 -05:00
parent d9f90965c8
commit 9a0a5d32ee
3 changed files with 89 additions and 41 deletions

View File

@ -3,7 +3,9 @@ import comfy.sd
import comfy.model_management import comfy.model_management
import nodes import nodes
import torch import torch
import re import comfy_extras.nodes_slg
class TripleCLIPLoader: class TripleCLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -23,6 +25,7 @@ class TripleCLIPLoader:
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,) return (clip,)
class EmptySD3LatentImage: class EmptySD3LatentImage:
def __init__(self): def __init__(self):
self.device = comfy.model_management.intermediate_device() self.device = comfy.model_management.intermediate_device()
@ -41,6 +44,7 @@ class EmptySD3LatentImage:
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device) latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, ) return ({"samples":latent}, )
class CLIPTextEncodeSD3: class CLIPTextEncodeSD3:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -97,7 +101,8 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
CATEGORY = "conditioning/controlnet" CATEGORY = "conditioning/controlnet"
DEPRECATED = True DEPRECATED = True
class SkipLayerGuidanceSD3:
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
''' '''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
@ -112,48 +117,12 @@ class SkipLayerGuidanceSD3:
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance" FUNCTION = "skip_guidance_sd3"
CATEGORY = "advanced/guidance" CATEGORY = "advanced/guidance"
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
def skip_guidance(self, model, layers, scale, start_percent, end_percent): return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
if layers == "" or layers == None:
return (model, )
# check if layer is comma separated integers
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
layers = re.findall(r'\d+', layers)
layers = [int(i) for i in layers]
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()
for layer in layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
model_sampling.percent_to_sigma(start_percent)
sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result
m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {

78
comfy_extras/nodes_slg.py Normal file
View File

@ -0,0 +1,78 @@
import comfy.model_patcher
import comfy.samplers
import re
class SkipLayerGuidanceDiT:
'''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Original experimental implementation for SD3 by Dango233@StabilityAI.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
EXPERIMENTAL = True
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""):
# check if layer is comma separated integers
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
double_layers = re.findall(r'\d+', double_layers)
double_layers = [int(i) for i in double_layers]
single_layers = re.findall(r'\d+', single_layers)
single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0:
return (model, )
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()
for layer in double_layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
for layer in single_layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer)
model_sampling.percent_to_sigma(start_percent)
sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result
m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
}

View File

@ -2133,6 +2133,7 @@ def init_builtin_extra_nodes():
"nodes_lora_extract.py", "nodes_lora_extract.py",
"nodes_torch_compile.py", "nodes_torch_compile.py",
"nodes_mochi.py", "nodes_mochi.py",
"nodes_slg.py",
] ]
import_failed = [] import_failed = []