From 8883cb0f678d4a7ef58f9cfa7ae16f8b0b4b8da9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 18 Jun 2023 22:58:22 -0400 Subject: [PATCH] Add a way to set patches that modify the attn2 output. Change the transformer patches function format to be more future proof. --- comfy/gligen.py | 3 ++- comfy/ldm/modules/attention.py | 18 +++++++++++++----- comfy/sd.py | 3 +++ comfy_extras/nodes_hypernetwork.py | 2 +- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/comfy/gligen.py b/comfy/gligen.py index 8c7cb432..fe3895c4 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -260,7 +260,8 @@ class Gligen(nn.Module): return r return func_lowvram else: - def func(key, x): + def func(x, extra_options): + key = extra_options["transformer_index"] module = self.module_list[key] return module(x, objs) return func diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 62707dfd..a0d69569 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -524,9 +524,11 @@ class BasicTransformerBlock(nn.Module): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) def _forward(self, x, context=None, transformer_options={}): - current_index = None + extra_options = {} if "current_index" in transformer_options: - current_index = transformer_options["current_index"] + extra_options["transformer_index"] = transformer_options["current_index"] + if "block_index" in transformer_options: + extra_options["block_index"] = transformer_options["block_index"] if "patches" in transformer_options: transformer_patches = transformer_options["patches"] else: @@ -545,7 +547,7 @@ class BasicTransformerBlock(nn.Module): context_attn1 = n value_attn1 = context_attn1 for p in patch: - n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1) + n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) if "tomesd" in transformer_options: m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) @@ -557,7 +559,7 @@ class BasicTransformerBlock(nn.Module): if "middle_patch" in transformer_patches: patch = transformer_patches["middle_patch"] for p in patch: - x = p(current_index, x) + x = p(x, extra_options) n = self.norm2(x) @@ -567,10 +569,15 @@ class BasicTransformerBlock(nn.Module): patch = transformer_patches["attn2_patch"] value_attn2 = context_attn2 for p in patch: - n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2) + n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) n = self.attn2(n, context=context_attn2, value=value_attn2) + if "attn2_output_patch" in transformer_patches: + patch = transformer_patches["attn2_output_patch"] + for p in patch: + n = p(n, extra_options) + x += n x = self.ff(self.norm3(x)) + x return x @@ -631,6 +638,7 @@ class SpatialTransformer(nn.Module): if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i x = block(x, context=context[i], transformer_options=transformer_options) if self.use_linear: x = self.proj_out(x) diff --git a/comfy/sd.py b/comfy/sd.py index 7f04ae3a..e6cda513 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -331,6 +331,9 @@ class ModelPatcher: def set_model_attn2_patch(self, patch): self.set_model_patch(patch, "attn2_patch") + def set_model_attn2_output_patch(self, patch): + self.set_model_patch(patch, "attn2_output_patch") + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index c19b5e4c..d16c49ae 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength): def __init__(self, hypernet, strength): self.hypernet = hypernet self.strength = strength - def __call__(self, current_index, q, k, v): + def __call__(self, q, k, v, extra_options): dim = k.shape[-1] if dim in self.hypernet: hn = self.hypernet[dim]