Add a way to set patches that modify the attn2 output.

Change the transformer patches function format to be more future proof.
This commit is contained in:
comfyanonymous 2023-06-18 22:58:22 -04:00
parent cd930d4e7f
commit 8883cb0f67
4 changed files with 19 additions and 7 deletions

View File

@ -260,7 +260,8 @@ class Gligen(nn.Module):
return r return r
return func_lowvram return func_lowvram
else: else:
def func(key, x): def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key] module = self.module_list[key]
return module(x, objs) return module(x, objs)
return func return func

View File

@ -524,9 +524,11 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}): def _forward(self, x, context=None, transformer_options={}):
current_index = None extra_options = {}
if "current_index" in transformer_options: if "current_index" in transformer_options:
current_index = transformer_options["current_index"] extra_options["transformer_index"] = transformer_options["current_index"]
if "block_index" in transformer_options:
extra_options["block_index"] = transformer_options["block_index"]
if "patches" in transformer_options: if "patches" in transformer_options:
transformer_patches = transformer_options["patches"] transformer_patches = transformer_options["patches"]
else: else:
@ -545,7 +547,7 @@ class BasicTransformerBlock(nn.Module):
context_attn1 = n context_attn1 = n
value_attn1 = context_attn1 value_attn1 = context_attn1
for p in patch: for p in patch:
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1) n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
if "tomesd" in transformer_options: if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@ -557,7 +559,7 @@ class BasicTransformerBlock(nn.Module):
if "middle_patch" in transformer_patches: if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"] patch = transformer_patches["middle_patch"]
for p in patch: for p in patch:
x = p(current_index, x) x = p(x, extra_options)
n = self.norm2(x) n = self.norm2(x)
@ -567,10 +569,15 @@ class BasicTransformerBlock(nn.Module):
patch = transformer_patches["attn2_patch"] patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2 value_attn2 = context_attn2
for p in patch: for p in patch:
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2) n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
n = self.attn2(n, context=context_attn2, value=value_attn2) n = self.attn2(n, context=context_attn2, value=value_attn2)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
for p in patch:
n = p(n, extra_options)
x += n x += n
x = self.ff(self.norm3(x)) + x x = self.ff(self.norm3(x)) + x
return x return x
@ -631,6 +638,7 @@ class SpatialTransformer(nn.Module):
if self.use_linear: if self.use_linear:
x = self.proj_in(x) x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
x = block(x, context=context[i], transformer_options=transformer_options) x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear: if self.use_linear:
x = self.proj_out(x) x = self.proj_out(x)

View File

@ -331,6 +331,9 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch): def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch") self.set_model_patch(patch, "attn2_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device): def model_patches_to(self, device):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches" in to: if "patches" in to:

View File

@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
def __init__(self, hypernet, strength): def __init__(self, hypernet, strength):
self.hypernet = hypernet self.hypernet = hypernet
self.strength = strength self.strength = strength
def __call__(self, current_index, q, k, v): def __call__(self, q, k, v, extra_options):
dim = k.shape[-1] dim = k.shape[-1]
if dim in self.hypernet: if dim in self.hypernet:
hn = self.hypernet[dim] hn = self.hypernet[dim]