Add a way to set patches that modify the attn2 output.
Change the transformer patches function format to be more future proof.
This commit is contained in:
parent
cd930d4e7f
commit
8883cb0f67
|
@ -260,7 +260,8 @@ class Gligen(nn.Module):
|
|||
return r
|
||||
return func_lowvram
|
||||
else:
|
||||
def func(key, x):
|
||||
def func(x, extra_options):
|
||||
key = extra_options["transformer_index"]
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return func
|
||||
|
|
|
@ -524,9 +524,11 @@ class BasicTransformerBlock(nn.Module):
|
|||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, transformer_options={}):
|
||||
current_index = None
|
||||
extra_options = {}
|
||||
if "current_index" in transformer_options:
|
||||
current_index = transformer_options["current_index"]
|
||||
extra_options["transformer_index"] = transformer_options["current_index"]
|
||||
if "block_index" in transformer_options:
|
||||
extra_options["block_index"] = transformer_options["block_index"]
|
||||
if "patches" in transformer_options:
|
||||
transformer_patches = transformer_options["patches"]
|
||||
else:
|
||||
|
@ -545,7 +547,7 @@ class BasicTransformerBlock(nn.Module):
|
|||
context_attn1 = n
|
||||
value_attn1 = context_attn1
|
||||
for p in patch:
|
||||
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
|
||||
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
|
||||
|
||||
if "tomesd" in transformer_options:
|
||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||
|
@ -557,7 +559,7 @@ class BasicTransformerBlock(nn.Module):
|
|||
if "middle_patch" in transformer_patches:
|
||||
patch = transformer_patches["middle_patch"]
|
||||
for p in patch:
|
||||
x = p(current_index, x)
|
||||
x = p(x, extra_options)
|
||||
|
||||
n = self.norm2(x)
|
||||
|
||||
|
@ -567,10 +569,15 @@ class BasicTransformerBlock(nn.Module):
|
|||
patch = transformer_patches["attn2_patch"]
|
||||
value_attn2 = context_attn2
|
||||
for p in patch:
|
||||
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
|
||||
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
|
||||
if "attn2_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_output_patch"]
|
||||
for p in patch:
|
||||
n = p(n, extra_options)
|
||||
|
||||
x += n
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
@ -631,6 +638,7 @@ class SpatialTransformer(nn.Module):
|
|||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
|
|
|
@ -331,6 +331,9 @@ class ModelPatcher:
|
|||
def set_model_attn2_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn2_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
|
|
|
@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
|
|||
def __init__(self, hypernet, strength):
|
||||
self.hypernet = hypernet
|
||||
self.strength = strength
|
||||
def __call__(self, current_index, q, k, v):
|
||||
def __call__(self, q, k, v, extra_options):
|
||||
dim = k.shape[-1]
|
||||
if dim in self.hypernet:
|
||||
hn = self.hypernet[dim]
|
||||
|
|
Loading…
Reference in New Issue