diff --git a/comfy/samplers.py b/comfy/samplers.py index ffc1fe3a..1cdad736 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,6 +1,7 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch +import torch.nn.functional as F import enum from comfy import model_management import math @@ -60,10 +61,10 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option for t in range(rr): mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - conditionning = {} + conditioning = {} model_conds = conds["model_conds"] for c in model_conds: - conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) control = None if 'control' in conds: @@ -82,7 +83,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option patches['middle_patch'] = [gligen_patch] - return (input_x, mult, conditionning, area, control, patches) + return (input_x, mult, conditioning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -246,15 +247,71 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option return out_cond, out_uncond - if math.isclose(cond_scale, 1.0): + # if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out. + if math.isclose(cond_scale, 1.0) and "sag" not in model_options: uncond = None - cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale if "sampler_cfg_function" in model_options: - args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} - return x - model_options["sampler_cfg_function"](args) - else: - return uncond + (cond - uncond) * cond_scale + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} + cfg_result = x - model_options["sampler_cfg_function"](args) + + if "sag" in model_options: + assert uncond is not None, "SAG requires uncond guidance" + sag_scale = model_options["sag_scale"] + sag_sigma = model_options["sag_sigma"] + sag_threshold = model_options.get("sag_threshold", 1.0) + + # these methods are added by the sag patcher + uncond_attn = model.get_attn_scores() + mid_shape = model.get_mid_block_shape() + # create the adversarially blurred image + degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # call into the UNet + (sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options) + cfg_result += (degraded - sag) * sag_scale + return cfg_result + +def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0): + # reshape and GAP the attention map + _, hw1, hw2 = attn.shape + b, _, lh, lw = x0.shape + attn = attn.reshape(b, -1, hw1, hw2) + # Global Average Pool + mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + # Reshape + mask = ( + mask.reshape(b, *mid_shape) + .unsqueeze(1) + .type(attn.dtype) + ) + # Upsample + mask = F.interpolate(mask, (lh, lw)) + + blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) + blurred = blurred * mask + x0 * (1 - mask) + return blurred + +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + return img class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py new file mode 100644 index 00000000..1ec0c93a --- /dev/null +++ b/comfy_extras/nodes_sag.py @@ -0,0 +1,115 @@ +import torch +from torch import einsum +from einops import rearrange, repeat +import os +from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION + +# from comfy/ldm/modules/attention.py +# but modified to return attention scores as well as output +def attention_basic_with_sim(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + + h = heads + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * scale + + del q, k + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + return (out, sim) + +class SagNode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, scale, blur_sigma): + m = model.clone() + # set extra options on the model + m.model_options["sag"] = True + m.model_options["sag_scale"] = scale + m.model_options["sag_sigma"] = blur_sigma + + attn_scores = None + mid_block_shape = None + m.model.get_attn_scores = lambda: attn_scores + m.model.get_mid_block_shape = lambda: mid_block_shape + + # TODO: make this work properly with chunked batches + # currently, we can only save the attn from one UNet call + def attn_and_record(q, k, v, extra_options): + nonlocal attn_scores + # if uncond, save the attention scores + heads = extra_options["n_heads"] + cond_or_uncond = extra_options["cond_or_uncond"] + b = q.shape[0] // len(cond_or_uncond) + if 1 in cond_or_uncond: + uncond_index = cond_or_uncond.index(1) + # do the entire attention operation, but save the attention scores to attn_scores + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads) + # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn] + n_slices = heads * b + attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] + return out + else: + return optimized_attention(q, k, v, heads=heads) + + # from diffusers: + # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch + def set_model_patch_replace(patch, name, key): + to = m.model_options["transformer_options"] + if "patches_replace" not in to: + to["patches_replace"] = {} + if name not in to["patches_replace"]: + to["patches_replace"][name] = {} + to["patches_replace"][name][key] = patch + set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0)) + # from diffusers: + # unet.mid_block.attentions[0].register_forward_hook() + def forward_hook(m, inp, out): + nonlocal mid_block_shape + mid_block_shape = out[0].shape[-2:] + m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "Self-Attention Guidance": SagNode, +} diff --git a/nodes.py b/nodes.py index db96e0e2..3d24750c 100644 --- a/nodes.py +++ b/nodes.py @@ -1867,6 +1867,7 @@ def init_custom_nodes(): "nodes_model_downscale.py", "nodes_images.py", "nodes_video_model.py", + "nodes_sag.py", ] for node_file in extras_files: