Add a way to set patches that modify the attn2 output.
Change the transformer patches function format to be more future proof.
This commit is contained in:
parent
cd930d4e7f
commit
8883cb0f67
|
@ -260,7 +260,8 @@ class Gligen(nn.Module):
|
||||||
return r
|
return r
|
||||||
return func_lowvram
|
return func_lowvram
|
||||||
else:
|
else:
|
||||||
def func(key, x):
|
def func(x, extra_options):
|
||||||
|
key = extra_options["transformer_index"]
|
||||||
module = self.module_list[key]
|
module = self.module_list[key]
|
||||||
return module(x, objs)
|
return module(x, objs)
|
||||||
return func
|
return func
|
||||||
|
|
|
@ -524,9 +524,11 @@ class BasicTransformerBlock(nn.Module):
|
||||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||||
|
|
||||||
def _forward(self, x, context=None, transformer_options={}):
|
def _forward(self, x, context=None, transformer_options={}):
|
||||||
current_index = None
|
extra_options = {}
|
||||||
if "current_index" in transformer_options:
|
if "current_index" in transformer_options:
|
||||||
current_index = transformer_options["current_index"]
|
extra_options["transformer_index"] = transformer_options["current_index"]
|
||||||
|
if "block_index" in transformer_options:
|
||||||
|
extra_options["block_index"] = transformer_options["block_index"]
|
||||||
if "patches" in transformer_options:
|
if "patches" in transformer_options:
|
||||||
transformer_patches = transformer_options["patches"]
|
transformer_patches = transformer_options["patches"]
|
||||||
else:
|
else:
|
||||||
|
@ -545,7 +547,7 @@ class BasicTransformerBlock(nn.Module):
|
||||||
context_attn1 = n
|
context_attn1 = n
|
||||||
value_attn1 = context_attn1
|
value_attn1 = context_attn1
|
||||||
for p in patch:
|
for p in patch:
|
||||||
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
|
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
|
||||||
|
|
||||||
if "tomesd" in transformer_options:
|
if "tomesd" in transformer_options:
|
||||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||||
|
@ -557,7 +559,7 @@ class BasicTransformerBlock(nn.Module):
|
||||||
if "middle_patch" in transformer_patches:
|
if "middle_patch" in transformer_patches:
|
||||||
patch = transformer_patches["middle_patch"]
|
patch = transformer_patches["middle_patch"]
|
||||||
for p in patch:
|
for p in patch:
|
||||||
x = p(current_index, x)
|
x = p(x, extra_options)
|
||||||
|
|
||||||
n = self.norm2(x)
|
n = self.norm2(x)
|
||||||
|
|
||||||
|
@ -567,10 +569,15 @@ class BasicTransformerBlock(nn.Module):
|
||||||
patch = transformer_patches["attn2_patch"]
|
patch = transformer_patches["attn2_patch"]
|
||||||
value_attn2 = context_attn2
|
value_attn2 = context_attn2
|
||||||
for p in patch:
|
for p in patch:
|
||||||
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
|
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||||
|
|
||||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||||
|
|
||||||
|
if "attn2_output_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn2_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
n = p(n, extra_options)
|
||||||
|
|
||||||
x += n
|
x += n
|
||||||
x = self.ff(self.norm3(x)) + x
|
x = self.ff(self.norm3(x)) + x
|
||||||
return x
|
return x
|
||||||
|
@ -631,6 +638,7 @@ class SpatialTransformer(nn.Module):
|
||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_in(x)
|
x = self.proj_in(x)
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
x = block(x, context=context[i], transformer_options=transformer_options)
|
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|
|
@ -331,6 +331,9 @@ class ModelPatcher:
|
||||||
def set_model_attn2_patch(self, patch):
|
def set_model_attn2_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn2_patch")
|
self.set_model_patch(patch, "attn2_patch")
|
||||||
|
|
||||||
|
def set_model_attn2_output_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "attn2_output_patch")
|
||||||
|
|
||||||
def model_patches_to(self, device):
|
def model_patches_to(self, device):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
if "patches" in to:
|
if "patches" in to:
|
||||||
|
|
|
@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
|
||||||
def __init__(self, hypernet, strength):
|
def __init__(self, hypernet, strength):
|
||||||
self.hypernet = hypernet
|
self.hypernet = hypernet
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
def __call__(self, current_index, q, k, v):
|
def __call__(self, q, k, v, extra_options):
|
||||||
dim = k.shape[-1]
|
dim = k.shape[-1]
|
||||||
if dim in self.hypernet:
|
if dim in self.hypernet:
|
||||||
hn = self.hypernet[dim]
|
hn = self.hypernet[dim]
|
||||||
|
|
Loading…
Reference in New Issue