Add basic PAG node.
This commit is contained in:
parent
258dbc06c3
commit
719fb2c81d
|
@ -18,6 +18,26 @@ def apply_weight_decompose(dora_scale, weight):
|
||||||
|
|
||||||
return weight * (dora_scale / weight_norm)
|
return weight * (dora_scale / weight_norm)
|
||||||
|
|
||||||
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||||
|
to = model_options["transformer_options"].copy()
|
||||||
|
|
||||||
|
if "patches_replace" not in to:
|
||||||
|
to["patches_replace"] = {}
|
||||||
|
else:
|
||||||
|
to["patches_replace"] = to["patches_replace"].copy()
|
||||||
|
|
||||||
|
if name not in to["patches_replace"]:
|
||||||
|
to["patches_replace"][name] = {}
|
||||||
|
else:
|
||||||
|
to["patches_replace"][name] = to["patches_replace"][name].copy()
|
||||||
|
|
||||||
|
if transformer_index is not None:
|
||||||
|
block = (block_name, number, transformer_index)
|
||||||
|
else:
|
||||||
|
block = (block_name, number)
|
||||||
|
to["patches_replace"][name][block] = patch
|
||||||
|
model_options["transformer_options"] = to
|
||||||
|
return model_options
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||||
|
@ -109,16 +129,7 @@ class ModelPatcher:
|
||||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||||
|
|
||||||
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
|
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
|
||||||
to = self.model_options["transformer_options"]
|
self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)
|
||||||
if "patches_replace" not in to:
|
|
||||||
to["patches_replace"] = {}
|
|
||||||
if name not in to["patches_replace"]:
|
|
||||||
to["patches_replace"][name] = {}
|
|
||||||
if transformer_index is not None:
|
|
||||||
block = (block_name, number, transformer_index)
|
|
||||||
else:
|
|
||||||
block = (block_name, number)
|
|
||||||
to["patches_replace"][name][block] = patch
|
|
||||||
|
|
||||||
def set_model_attn1_patch(self, patch):
|
def set_model_attn1_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn1_patch")
|
self.set_model_patch(patch, "attn1_patch")
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
#Modified/simplified version of the node from: https://github.com/pamparamm/sd-perturbed-attention
|
||||||
|
#If you want the one with more options see the above repo.
|
||||||
|
|
||||||
|
#My modified one here is more basic but has less chances of breaking with ComfyUI updates.
|
||||||
|
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.samplers
|
||||||
|
|
||||||
|
class PerturbedAttentionGuidance:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL",),
|
||||||
|
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def patch(self, model, scale):
|
||||||
|
unet_block = "middle"
|
||||||
|
unet_block_id = 0
|
||||||
|
m = model.clone()
|
||||||
|
|
||||||
|
def perturbed_attention(q, k, v, extra_options, mask=None):
|
||||||
|
return v
|
||||||
|
|
||||||
|
def post_cfg_function(args):
|
||||||
|
model = args["model"]
|
||||||
|
cond_pred = args["cond_denoised"]
|
||||||
|
cond = args["cond"]
|
||||||
|
cfg_result = args["denoised"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
model_options = args["model_options"].copy()
|
||||||
|
x = args["input"]
|
||||||
|
|
||||||
|
if scale == 0:
|
||||||
|
return cfg_result
|
||||||
|
|
||||||
|
# Replace Self-attention with PAG
|
||||||
|
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, perturbed_attention, "attn1", unet_block, unet_block_id)
|
||||||
|
(pag,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||||
|
|
||||||
|
return cfg_result + (cond_pred - pag) * scale
|
||||||
|
|
||||||
|
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
return (m,)
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"PerturbedAttentionGuidance": PerturbedAttentionGuidance,
|
||||||
|
}
|
1
nodes.py
1
nodes.py
|
@ -1942,6 +1942,7 @@ def init_custom_nodes():
|
||||||
"nodes_differential_diffusion.py",
|
"nodes_differential_diffusion.py",
|
||||||
"nodes_ip2p.py",
|
"nodes_ip2p.py",
|
||||||
"nodes_model_merging_model_specific.py",
|
"nodes_model_merging_model_specific.py",
|
||||||
|
"nodes_pag.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|
Loading…
Reference in New Issue