Add a skip layer guidance node that can also skip single layers.
This one should work for skipping the single layers of models like Flux and Auraflow. If you want to see how these models work and how many double/single layers they have see the "ModelMerge*" nodes for the specific model.
This commit is contained in:
parent
d9f90965c8
commit
9a0a5d32ee
|
@ -3,7 +3,9 @@ import comfy.sd
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import nodes
|
import nodes
|
||||||
import torch
|
import torch
|
||||||
import re
|
import comfy_extras.nodes_slg
|
||||||
|
|
||||||
|
|
||||||
class TripleCLIPLoader:
|
class TripleCLIPLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
@ -23,6 +25,7 @@ class TripleCLIPLoader:
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
|
|
||||||
class EmptySD3LatentImage:
|
class EmptySD3LatentImage:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.device = comfy.model_management.intermediate_device()
|
self.device = comfy.model_management.intermediate_device()
|
||||||
|
@ -41,6 +44,7 @@ class EmptySD3LatentImage:
|
||||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
||||||
return ({"samples":latent}, )
|
return ({"samples":latent}, )
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeSD3:
|
class CLIPTextEncodeSD3:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
@ -97,7 +101,8 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
DEPRECATED = True
|
DEPRECATED = True
|
||||||
|
|
||||||
class SkipLayerGuidanceSD3:
|
|
||||||
|
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
|
||||||
'''
|
'''
|
||||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
|
@ -112,48 +117,12 @@ class SkipLayerGuidanceSD3:
|
||||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "skip_guidance"
|
FUNCTION = "skip_guidance_sd3"
|
||||||
|
|
||||||
CATEGORY = "advanced/guidance"
|
CATEGORY = "advanced/guidance"
|
||||||
|
|
||||||
|
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
|
||||||
def skip_guidance(self, model, layers, scale, start_percent, end_percent):
|
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
|
||||||
if layers == "" or layers == None:
|
|
||||||
return (model, )
|
|
||||||
# check if layer is comma separated integers
|
|
||||||
def skip(args, extra_args):
|
|
||||||
return args
|
|
||||||
|
|
||||||
model_sampling = model.get_model_object("model_sampling")
|
|
||||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
|
||||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
|
||||||
|
|
||||||
layers = re.findall(r'\d+', layers)
|
|
||||||
layers = [int(i) for i in layers]
|
|
||||||
|
|
||||||
def post_cfg_function(args):
|
|
||||||
model = args["model"]
|
|
||||||
cond_pred = args["cond_denoised"]
|
|
||||||
cond = args["cond"]
|
|
||||||
cfg_result = args["denoised"]
|
|
||||||
sigma = args["sigma"]
|
|
||||||
x = args["input"]
|
|
||||||
model_options = args["model_options"].copy()
|
|
||||||
|
|
||||||
for layer in layers:
|
|
||||||
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
|
||||||
model_sampling.percent_to_sigma(start_percent)
|
|
||||||
|
|
||||||
sigma_ = sigma[0].item()
|
|
||||||
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
|
||||||
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
|
||||||
cfg_result = cfg_result + (cond_pred - slg) * scale
|
|
||||||
return cfg_result
|
|
||||||
|
|
||||||
m = model.clone()
|
|
||||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
|
||||||
|
|
||||||
return (m, )
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.samplers
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class SkipLayerGuidanceDiT:
|
||||||
|
'''
|
||||||
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
|
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
||||||
|
'''
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"model": ("MODEL", ),
|
||||||
|
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||||
|
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||||
|
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "skip_guidance"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
|
||||||
|
|
||||||
|
CATEGORY = "advanced/guidance"
|
||||||
|
|
||||||
|
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""):
|
||||||
|
# check if layer is comma separated integers
|
||||||
|
def skip(args, extra_args):
|
||||||
|
return args
|
||||||
|
|
||||||
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
|
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||||
|
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||||
|
|
||||||
|
double_layers = re.findall(r'\d+', double_layers)
|
||||||
|
double_layers = [int(i) for i in double_layers]
|
||||||
|
|
||||||
|
single_layers = re.findall(r'\d+', single_layers)
|
||||||
|
single_layers = [int(i) for i in single_layers]
|
||||||
|
|
||||||
|
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||||
|
return (model, )
|
||||||
|
|
||||||
|
def post_cfg_function(args):
|
||||||
|
model = args["model"]
|
||||||
|
cond_pred = args["cond_denoised"]
|
||||||
|
cond = args["cond"]
|
||||||
|
cfg_result = args["denoised"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
x = args["input"]
|
||||||
|
model_options = args["model_options"].copy()
|
||||||
|
|
||||||
|
for layer in double_layers:
|
||||||
|
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
||||||
|
|
||||||
|
for layer in single_layers:
|
||||||
|
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer)
|
||||||
|
|
||||||
|
model_sampling.percent_to_sigma(start_percent)
|
||||||
|
|
||||||
|
sigma_ = sigma[0].item()
|
||||||
|
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||||
|
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||||
|
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||||
|
return cfg_result
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
|
||||||
|
}
|
Loading…
Reference in New Issue