From 9a0a5d32ee49e79beae551fe2a165ac1108378e9 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 18 Nov 2024 02:20:43 -0500
Subject: [PATCH] Add a skip layer guidance node that can also skip single
 layers.

This one should work for skipping the single layers of models like Flux and
Auraflow.

If you want to see how these models work and how many double/single layers
they have, see the "ModelMerge*" nodes for the specific model.
---
 comfy_extras/nodes_sd3.py | 51 +++++--------------------
 comfy_extras/nodes_slg.py | 78 +++++++++++++++++++++++++++++++++++++++
 nodes.py                  |  1 +
 3 files changed, 89 insertions(+), 41 deletions(-)
 create mode 100644 comfy_extras/nodes_slg.py

diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py
index e95f20b9..6ef3c293 100644
--- a/comfy_extras/nodes_sd3.py
+++ b/comfy_extras/nodes_sd3.py
@@ -3,7 +3,9 @@ import comfy.sd
 import comfy.model_management
 import nodes
 import torch
-import re
+import comfy_extras.nodes_slg
+
+
 class TripleCLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -23,6 +25,7 @@ class TripleCLIPLoader:
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)
 
+
 class EmptySD3LatentImage:
     def __init__(self):
         self.device = comfy.model_management.intermediate_device()
@@ -41,6 +44,7 @@ class EmptySD3LatentImage:
         latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
         return ({"samples":latent}, )
 
+
 class CLIPTextEncodeSD3:
     @classmethod
     def INPUT_TYPES(s):
@@ -97,7 +101,8 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
     CATEGORY = "conditioning/controlnet"
     DEPRECATED = True
 
-class SkipLayerGuidanceSD3:
+
+class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
     '''
     Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
     Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
@@ -112,48 +117,12 @@ class SkipLayerGuidanceSD3:
                              "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
                              }}
     RETURN_TYPES = ("MODEL",)
-    FUNCTION = "skip_guidance"
+    FUNCTION = "skip_guidance_sd3"
 
     CATEGORY = "advanced/guidance"
 
-
-    def skip_guidance(self, model, layers, scale, start_percent, end_percent):
-        if layers == "" or layers == None:
-            return (model, )
-        # check if layer is comma separated integers
-        def skip(args, extra_args):
-            return args
-
-        model_sampling = model.get_model_object("model_sampling")
-        sigma_start = model_sampling.percent_to_sigma(start_percent)
-        sigma_end = model_sampling.percent_to_sigma(end_percent)
-
-        layers = re.findall(r'\d+', layers)
-        layers = [int(i) for i in layers]
-
-        def post_cfg_function(args):
-            model = args["model"]
-            cond_pred = args["cond_denoised"]
-            cond = args["cond"]
-            cfg_result = args["denoised"]
-            sigma = args["sigma"]
-            x = args["input"]
-            model_options = args["model_options"].copy()
-
-            for layer in layers:
-                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
-            model_sampling.percent_to_sigma(start_percent)
-
-            sigma_ = sigma[0].item()
-            if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
-                (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
-                cfg_result = cfg_result + (cond_pred - slg) * scale
-            return cfg_result
-
-        m = model.clone()
-        m.set_model_sampler_post_cfg_function(post_cfg_function)
-
-        return (m, )
+    def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
+        return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
 
 
 NODE_CLASS_MAPPINGS = {
diff --git a/comfy_extras/nodes_slg.py b/comfy_extras/nodes_slg.py
new file mode 100644
index 00000000..8a1181fc
--- /dev/null
+++ b/comfy_extras/nodes_slg.py
@@ -0,0 +1,78 @@
+import comfy.model_patcher
+import comfy.samplers
+import re
+
+
+class SkipLayerGuidanceDiT:
+    '''
+    Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
+    Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
+    Original experimental implementation for SD3 by Dango233@StabilityAI.
+    '''
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL", ),
+                             "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
+                             "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
+                             "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
+                             "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "skip_guidance"
+    EXPERIMENTAL = True
+
+    DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
+
+    CATEGORY = "advanced/guidance"
+
+    def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""):
+        # check if layer is comma separated integers
+        def skip(args, extra_args):
+            return args
+
+        model_sampling = model.get_model_object("model_sampling")
+        sigma_start = model_sampling.percent_to_sigma(start_percent)
+        sigma_end = model_sampling.percent_to_sigma(end_percent)
+
+        double_layers = re.findall(r'\d+', double_layers)
+        double_layers = [int(i) for i in double_layers]
+
+        single_layers = re.findall(r'\d+', single_layers)
+        single_layers = [int(i) for i in single_layers]
+
+        if len(double_layers) == 0 and len(single_layers) == 0:
+            return (model, )
+
+        def post_cfg_function(args):
+            model = args["model"]
+            cond_pred = args["cond_denoised"]
+            cond = args["cond"]
+            cfg_result = args["denoised"]
+            sigma = args["sigma"]
+            x = args["input"]
+            model_options = args["model_options"].copy()
+
+            for layer in double_layers:
+                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
+
+            for layer in single_layers:
+                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer)
+
+            model_sampling.percent_to_sigma(start_percent)
+
+            sigma_ = sigma[0].item()
+            if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
+                (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
+                cfg_result = cfg_result + (cond_pred - slg) * scale
+            return cfg_result
+
+        m = model.clone()
+        m.set_model_sampler_post_cfg_function(post_cfg_function)
+
+        return (m, )
+
+
+NODE_CLASS_MAPPINGS = {
+    "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
+}
diff --git a/nodes.py b/nodes.py
index ea1b3faa..1ac817a2 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2133,6 +2133,7 @@ def init_builtin_extra_nodes():
         "nodes_lora_extract.py",
         "nodes_torch_compile.py",
         "nodes_mochi.py",
+        "nodes_slg.py",
     ]
 
     import_failed = []
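For reference, a minimal usage sketch of the new generic node driven directly from Python rather than from the graph editor. It is not part of the patch: model_patcher stands in for a MODEL object produced by one of the regular loader nodes, and everything else follows the skip_guidance signature added in comfy_extras/nodes_slg.py.

    # Hypothetical example (not in this patch): skip single blocks 7-9 of a DiT
    # model such as Flux. model_patcher is assumed to come from a loader node.
    from comfy_extras.nodes_slg import SkipLayerGuidanceDiT

    slg = SkipLayerGuidanceDiT()
    (patched_model,) = slg.skip_guidance(model=model_patcher, scale=3.0,
                                         start_percent=0.01, end_percent=0.15,
                                         double_layers="", single_layers="7, 8, 9")

    # The layer strings are parsed with re.findall(r'\d+', ...), so any comma
    # separated list of integers works; if both lists end up empty the model is
    # returned unchanged. While the current sigma lies inside the start/end
    # percent window, sampling runs one extra conditional pass with the listed
    # blocks skipped and adds (cond_pred - slg_pred) * scale to the CFG result.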