diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index e95f20b9..6ef3c293 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -3,7 +3,9 @@ import comfy.sd import comfy.model_management import nodes import torch -import re +import comfy_extras.nodes_slg + + class TripleCLIPLoader: @classmethod def INPUT_TYPES(s): @@ -23,6 +25,7 @@ class TripleCLIPLoader: clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) return (clip,) + class EmptySD3LatentImage: def __init__(self): self.device = comfy.model_management.intermediate_device() @@ -41,6 +44,7 @@ class EmptySD3LatentImage: latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device) return ({"samples":latent}, ) + class CLIPTextEncodeSD3: @classmethod def INPUT_TYPES(s): @@ -97,7 +101,8 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): CATEGORY = "conditioning/controlnet" DEPRECATED = True -class SkipLayerGuidanceSD3: + +class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT): ''' Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) @@ -112,48 +117,12 @@ class SkipLayerGuidanceSD3: "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) }} RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance" + FUNCTION = "skip_guidance_sd3" CATEGORY = "advanced/guidance" - - def skip_guidance(self, model, layers, scale, start_percent, end_percent): - if layers == "" or layers == None: - return (model, ) - # check if layer is comma separated integers - def skip(args, extra_args): - return args - - model_sampling = model.get_model_object("model_sampling") - sigma_start = model_sampling.percent_to_sigma(start_percent) - sigma_end = model_sampling.percent_to_sigma(end_percent) - - layers = re.findall(r'\d+', layers) - layers = [int(i) for i in layers] - - def post_cfg_function(args): - model = args["model"] - cond_pred = args["cond_denoised"] - cond = args["cond"] - cfg_result = args["denoised"] - sigma = args["sigma"] - x = args["input"] - model_options = args["model_options"].copy() - - for layer in layers: - model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer) - model_sampling.percent_to_sigma(start_percent) - - sigma_ = sigma[0].item() - if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start: - (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) - cfg_result = cfg_result + (cond_pred - slg) * scale - return cfg_result - - m = model.clone() - m.set_model_sampler_post_cfg_function(post_cfg_function) - - return (m, ) + def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent): + return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_slg.py b/comfy_extras/nodes_slg.py new file mode 100644 index 00000000..8a1181fc --- /dev/null +++ b/comfy_extras/nodes_slg.py @@ -0,0 +1,78 @@ +import comfy.model_patcher +import comfy.samplers +import re + + +class SkipLayerGuidanceDiT: + ''' + Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. + Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) + Original experimental implementation for SD3 by Dango233@StabilityAI. + ''' + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), + "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), + "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), + "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "skip_guidance" + EXPERIMENTAL = True + + DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model." + + CATEGORY = "advanced/guidance" + + def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""): + # check if layer is comma separated integers + def skip(args, extra_args): + return args + + model_sampling = model.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + + double_layers = re.findall(r'\d+', double_layers) + double_layers = [int(i) for i in double_layers] + + single_layers = re.findall(r'\d+', single_layers) + single_layers = [int(i) for i in single_layers] + + if len(double_layers) == 0 and len(single_layers) == 0: + return (model, ) + + def post_cfg_function(args): + model = args["model"] + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + x = args["input"] + model_options = args["model_options"].copy() + + for layer in double_layers: + model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer) + + for layer in single_layers: + model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer) + + model_sampling.percent_to_sigma(start_percent) + + sigma_ = sigma[0].item() + if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start: + (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) + cfg_result = cfg_result + (cond_pred - slg) * scale + return cfg_result + + m = model.clone() + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT, +} diff --git a/nodes.py b/nodes.py index ea1b3faa..1ac817a2 100644 --- a/nodes.py +++ b/nodes.py @@ -2133,6 +2133,7 @@ def init_builtin_extra_nodes(): "nodes_lora_extract.py", "nodes_torch_compile.py", "nodes_mochi.py", + "nodes_slg.py", ] import_failed = []