2023-01-03 06:53:32 +00:00
from inspect import isfunction
import math
import torch
import torch . nn . functional as F
from torch import nn , einsum
from einops import rearrange , repeat
from typing import Optional , Any
2023-05-04 22:07:41 +00:00
from . diffusionmodules . util import checkpoint
2023-01-03 06:53:32 +00:00
from . sub_quadratic_attention import efficient_dot_product_attention
2023-04-15 22:55:17 +00:00
from comfy import model_management
2023-03-03 08:27:33 +00:00
2023-03-13 15:36:48 +00:00
if model_management . xformers_enabled ( ) :
2023-01-03 06:53:32 +00:00
import xformers
import xformers . ops
2023-07-02 15:57:36 +00:00
from comfy . cli_args import args
2023-08-18 20:32:23 +00:00
import comfy . ops
2023-01-03 06:53:32 +00:00
# CrossAttn precision handling
2023-07-02 15:57:36 +00:00
# Attention scores are computed in fp32 unless args.dont_upcast_attention is set,
# in which case the tensors keep their own dtype.
_ATTN_PRECISION = " fp32 "
if args.dont_upcast_attention:
    print(" disabling upcasting of attention ")
    _ATTN_PRECISION = " fp16 "
2023-01-03 06:53:32 +00:00
2023-04-06 03:41:23 +00:00
2023-01-03 06:53:32 +00:00
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
def uniq(arr):
    """Return a dict-keys view of the distinct elements of ``arr``, in first-seen order."""
    return dict.fromkeys(arr, True).keys()
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return the fallback ``d``."""
    return d if val is None else val
2023-01-03 06:53:32 +00:00
def max_neg_value(t):
    """Return the most negative finite value representable in ``t``'s floating dtype."""
    limits = torch.finfo(t.dtype)
    return -limits.max
def init_(tensor):
    """Fill ``tensor`` in place with uniform values in [-1/sqrt(d), 1/sqrt(d)].

    ``d`` is the size of the last dimension. The same tensor object is returned.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    # uniform_ mutates the tensor and hands the same object back.
    return tensor.uniform_(-bound, bound)
# feedforward
class GEGLU(nn.Module):
    """Gated GELU unit: one linear projection produces a value half and a gate half."""

    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        # A single projection of width 2 * dim_out; forward() splits it in two.
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        """Project ``x``, then multiply the value half by the GELU of the gate half."""
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Feed-forward stack: input projection (GELU or GEGLU), dropout, output projection."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        hidden_dim = int(dim * mult)
        # The output width falls back to the input width when not given.
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations)
        else:
            project_in = nn.Sequential(
                operations.Linear(dim, hidden_dim, dtype=dtype, device=device),
                nn.GELU(),
            )

        layers = [
            project_in,
            nn.Dropout(dropout),
            operations.Linear(hidden_dim, dim_out, dtype=dtype, device=device),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.net(x)
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the same module.
    """
    for param in module.parameters():
        # detach() shares storage with the parameter, so zeroing it zeroes the parameter
        # without going through autograd.
        detached = param.detach()
        detached.zero_()
    return module
2023-07-29 18:51:56 +00:00
def Normalize(in_channels, dtype=None, device=None):
    """Build a 32-group affine ``GroupNorm`` over ``in_channels`` channels with eps=1e-6."""
    norm_kwargs = dict(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(dtype=dtype, device=device, **norm_kwargs)
2023-01-03 06:53:32 +00:00
2023-10-11 19:47:53 +00:00
def attention_basic(q, k, v, heads, mask=None):
    """Plain softmax attention that materializes the full score matrix.

    ``q``, ``k`` and ``v`` are 3-D tensors whose last dimension is split evenly
    into ``heads`` heads. The result is reshaped back to
    (batch, tokens, heads * head_dim).

    When ``mask`` is given it is flattened to (batch, j), repeated per head, and
    positions where it is false are filled with the most negative finite value
    before the softmax (assumes a boolean mask, since it is inverted with ``~``
    -- TODO confirm).
    """
    batch, _, inner_dim = q.shape
    head_dim = inner_dim // heads
    scale = head_dim ** -0.5

    def split_heads(t):
        # (batch, tokens, heads * head_dim) -> (batch * heads, tokens, head_dim)
        t = t.unsqueeze(3).reshape(batch, -1, heads, head_dim)
        t = t.permute(0, 2, 1, 3).reshape(batch * heads, -1, head_dim)
        return t.contiguous()

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    if _ATTN_PRECISION == " fp32 ":
        # force cast to fp32 to avoid overflowing
        # NOTE(review): device_type is written with surrounding spaces here; confirm
        # torch.autocast accepts it.
        with torch.autocast(enabled=False, device_type=' cuda '):
            q = q.float()
            k = k.float()
            scores = einsum(' b i d, b j d -> b i j ', q, k) * scale
    else:
        scores = einsum(' b i d, b j d -> b i j ', q, k) * scale

    # q and k are no longer needed once the scores exist.
    del q, k

    if mask is not None:
        fill_value = -torch.finfo(scores.dtype).max
        flat_mask = rearrange(mask, ' b ... -> b (...) ')
        per_head_mask = repeat(flat_mask, ' b j -> (b h) () j ', h=heads)
        scores.masked_fill_(~per_head_mask, fill_value)

    # attention, what we cannot get enough of
    weights = scores.softmax(dim=-1)

    attended = einsum(' b i j, b j d -> b i d ', weights.to(v.dtype), v)

    # (batch * heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
    attended = attended.unsqueeze(0).reshape(batch, heads, -1, head_dim)
    attended = attended.permute(0, 2, 1, 3)
    return attended.reshape(batch, -1, heads * head_dim)
2023-01-03 06:53:32 +00:00
2023-10-11 19:47:53 +00:00
def attention_sub_quad(query, key, value, heads, mask=None):
    """Attention through the chunked ``efficient_dot_product_attention`` helper.

    The last dimension of ``query``/``key``/``value`` is split into ``heads``
    heads. Chunk sizes are derived from the free memory reported by
    ``model_management.get_free_memory``; when the whole query-key product fits
    under the threshold, everything is processed as a single chunk.

    The helper's output is cast back to the query dtype and the heads are merged
    into the last dimension again.

    NOTE(review): ``mask`` is accepted but never read in this function.
    """
    batch, _, inner_dim = query.shape
    head_dim = inner_dim // heads
    # NOTE(review): computed like in the sibling attention functions but not used below.
    scale = head_dim ** -0.5

    def split_heads(t):
        return t.unsqueeze(3).reshape(batch, -1, heads, head_dim)

    query = split_heads(query).permute(0, 2, 1, 3).reshape(batch * heads, -1, head_dim)
    value = split_heads(value).permute(0, 2, 1, 3).reshape(batch * heads, -1, head_dim)
    # key is laid out transposed: (batch * heads, head_dim, tokens)
    key = split_heads(key).permute(0, 2, 3, 1).reshape(batch * heads, head_dim, -1)

    dtype = query.dtype
    upcast_attention = _ATTN_PRECISION == " fp32 " and dtype != torch.float32
    # Size the score matrix with the dtype it will actually be computed in.
    score_dtype = torch.float32 if upcast_attention else dtype
    bytes_per_token = torch.finfo(score_dtype).bits // 8

    batch_x_heads, q_tokens, _ = query.shape
    k_tokens = key.shape[2]
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)

    # Using only this seems to work better on AMD
    chunk_threshold_bytes = mem_free_torch * 0.5

    # not sure at all about the math here
    # TODO: tweak this
    if mem_free_total > 8192 * 1024 * 1024 * 1.3:
        query_chunk_size_x = 1024 * 4
    elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
        query_chunk_size_x = 1024 * 2
    else:
        query_chunk_size_x = 1024

    # Round the kv chunk size down to a multiple of 1024; below 1024 it is left unset.
    bytes_per_kv_token = batch_x_heads * bytes_per_token * query_chunk_size_x
    kv_chunk_size_x = (int((chunk_threshold_bytes // bytes_per_kv_token) * 2.0) // 1024) * 1024
    if kv_chunk_size_x < 1024:
        kv_chunk_size_x = None

    if qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        query_chunk_size, kv_chunk_size = q_tokens, k_tokens
    else:
        query_chunk_size, kv_chunk_size = query_chunk_size_x, kv_chunk_size_x

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=None,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
    )

    hidden_states = hidden_states.to(dtype)

    # Split the leading dimension back into (batch, heads), then fold heads into the last dim.
    per_head = hidden_states.unflatten(0, (-1, heads))
    return per_head.transpose(1, 2).flatten(start_dim=2)
def attention_split(q, k, v, heads, mask=None):
    """Softmax attention computed in slices along the query-token axis.

    The last dimension of ``q``/``k``/``v`` is split into ``heads`` heads. The
    number of slices ("steps") is estimated from the free memory reported by
    ``model_management.get_free_memory`` and doubled whenever an out-of-memory
    exception is caught. The result is reshaped to
    (batch, tokens, heads * head_dim).

    Raises:
        RuntimeError: if the initial estimate needs more than 64 steps.
        model_management.OOM_EXCEPTION: re-raised when retrying cannot help.

    NOTE(review): ``mask`` is accepted but never read in this function.
    """
    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    h = heads
    # (batch, tokens, heads * dim_head) -> (batch * heads, tokens, dim_head)
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    # Output buffer, filled slice by slice below.
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = model_management.get_free_memory(q.device)

    # When upcasting, the score matrix is fp32 regardless of the input dtype.
    if _ATTN_PRECISION == " fp32 ":
        element_size = 4
    else:
        element_size = q.element_size()

    gb = 1024 ** 3
    # Bytes of the full (batch * heads, q_tokens, k_tokens) score matrix.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3 if element_size == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        # Smallest power of two that brings the per-slice estimate under the free memory.
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        # NOTE(review): the format specs below are written with surrounding spaces
        # (' 0.1f '); confirm they are valid format specifiers.
        raise RuntimeError(f' Not enough memory, use lower resolution (max approx. { max_res } x { max_res } ). '
                           f' Need: { mem_required / 64 / gb : 0.1f } GB free, Have: { mem_free_total / gb : 0.1f } GB free ')

    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
    first_op_done = False
    cleared_cache = False
    while True:
        try:
            # NOTE(review): when the token count is not divisible by steps, the slice
            # falls back to the full length, i.e. no slicing happens for that attempt.
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                if _ATTN_PRECISION == " fp32 ":
                    # NOTE(review): device_type is written with surrounding spaces;
                    # confirm torch.autocast accepts it.
                    with torch.autocast(enabled=False, device_type=' cuda '):
                        s1 = einsum(' b i d, b j d -> b i j ', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = einsum(' b i d, b j d -> b i j ', q[:, i:end], k) * scale
                # From here on an OOM is re-raised instead of retried.
                first_op_done = True

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1

                r1[:, i:end] = einsum(' b i j, b j d -> b i d ', s2, v)
                del s2
            break
        except model_management.OOM_EXCEPTION as e:
            if first_op_done == False:
                model_management.soft_empty_cache(True)
                # First OOM: only clear the cache and retry with the same step count.
                if cleared_cache == False:
                    cleared_cache = True
                    print(" out of memory error, emptying cache and trying again ")
                    continue
                # Later OOMs: halve the slice size, giving up past 64 steps.
                steps *= 2
                if steps > 64:
                    raise e
                print(" out of memory error, increasing steps and trying again ", steps)
            else:
                raise e

    del q, k, v

    # (batch * heads, tokens, dim_head) -> (batch, tokens, heads * dim_head)
    r1 = (
        r1.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return r1
2023-10-11 19:47:53 +00:00
def attention_xformers(q, k, v, heads, mask=None):
    """Attention via ``xformers.ops.memory_efficient_attention``.

    ``q``, ``k`` and ``v`` are 3-D tensors whose last dimension is split evenly
    into ``heads`` heads; the result is reshaped back to
    (batch, tokens, heads * head_dim).

    Raises:
        NotImplementedError: if ``mask`` is given; masks are not supported by
            this implementation.
    """
    # Reject masks before doing any work; previously the check ran only after the
    # full attention had already been computed.
    if mask is not None:
        raise NotImplementedError

    b, _, dim_head = q.shape
    dim_head //= heads

    def split_heads(t):
        # (batch, tokens, heads * dim_head) -> (batch * heads, tokens, dim_head)
        return (
            t.reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous()
        )

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # actually compute the attention, what we cannot get enough of
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    # (batch * heads, tokens, dim_head) -> (batch, tokens, heads * dim_head)
    out = (
        out.reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out
def attention_pytorch(q, k, v, heads, mask=None):
    """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

    The last dimension of ``q``/``k``/``v`` is viewed as ``heads`` heads, ``mask``
    is passed through as ``attn_mask`` (no dropout, non-causal), and the result is
    reshaped back to (batch, tokens, heads * head_dim).
    """
    batch, _, inner_dim = q.shape
    head_dim = inner_dim // heads

    def to_heads(t):
        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        return t.view(batch, -1, heads, head_dim).transpose(1, 2)

    q, k, v = to_heads(q), to_heads(k), to_heads(v)

    attended = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    merged = attended.transpose(1, 2)
    return merged.reshape(batch, -1, heads * head_dim)
2023-10-16 06:31:24 +00:00
2023-10-11 19:47:53 +00:00
# Both entry points start out on the basic implementation.
optimized_attention = optimized_attention_masked = attention_basic

# Pick the unmasked implementation: xformers first, then PyTorch's built-in
# attention, otherwise one of the two memory-saving fallbacks.
if model_management.xformers_enabled():
    print(" Using xformers cross attention ")
    optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
    print(" Using pytorch cross attention ")
    optimized_attention = attention_pytorch
elif args.use_split_cross_attention:
    print(" Using split optimization for cross attention ")
    optimized_attention = attention_split
else:
    print(" Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention ")
    optimized_attention = attention_sub_quad

# The masked entry point only leaves the basic implementation when PyTorch attention is enabled.
if model_management.pytorch_attention_enabled():
    optimized_attention_masked = attention_pytorch
2023-02-22 03:16:13 +00:00
class CrossAttention(nn.Module):
    """Multi-head attention with separate query, key and value projections.

    Keys and values are projected from ``context`` (or from the input itself when
    no context is given); the attention is delegated to the module-level
    ``optimized_attention`` / ``optimized_attention_masked`` functions.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        # Without an explicit context width, keys/values share the query width.
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        linear_kwargs = dict(bias=False, dtype=dtype, device=device)
        self.to_q = operations.Linear(query_dim, inner_dim, **linear_kwargs)
        self.to_k = operations.Linear(context_dim, inner_dim, **linear_kwargs)
        self.to_v = operations.Linear(context_dim, inner_dim, **linear_kwargs)

        out_proj = operations.Linear(inner_dim, query_dim, dtype=dtype, device=device)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))

    def forward(self, x, context=None, value=None, mask=None):
        """Attend from ``x`` to ``context``.

        ``value`` optionally supplies a separate source for the value projection;
        otherwise values are projected from the same tensor as the keys. The
        masked attention function is used only when ``mask`` is not None.
        """
        q = self.to_q(x)

        kv_source = x if context is None else context
        k = self.to_k(kv_source)
        v = self.to_v(kv_source if value is None else value)

        if mask is not None:
            attended = optimized_attention_masked(q, k, v, self.heads, mask)
        else:
            attended = optimized_attention(q, k, v, self.heads)
        return self.to_out(attended)
2023-03-05 19:14:54 +00:00
2023-01-03 06:53:32 +00:00
class BasicTransformerBlock(nn.Module):
    """Transformer block: attn1 -> attn2 -> feed-forward, each with a residual add.

    ``transformer_options`` may carry "patches" (callables run before/after each
    attention) and "patches_replace" (callables that replace an attention
    computation for a given block key).
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
        # Flag forwarded to the imported checkpoint() helper in forward().
        self.checkpoint = checkpoint
        # Kept so patches can read the head layout through extra_options.
        self.n_heads = n_heads
        self.d_head = d_head

    def forward(self, x, context=None, transformer_options={}):
        """Run ``_forward`` through the imported ``checkpoint`` helper."""
        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, transformer_options={}):
        """Apply both attention stages and the feed-forward, honoring any patches.

        NOTE: ``x`` is updated in place by the residual additions (``x += n``).
        """
        # extra_options is the per-call dict handed to every patch callable.
        extra_options = {}
        block = None
        block_index = 0
        if " current_index " in transformer_options:
            extra_options[" transformer_index "] = transformer_options[" current_index "]
        if " block_index " in transformer_options:
            block_index = transformer_options[" block_index "]
            extra_options[" block_index "] = block_index
        if " original_shape " in transformer_options:
            extra_options[" original_shape "] = transformer_options[" original_shape "]
        if " block " in transformer_options:
            block = transformer_options[" block "]
            extra_options[" block "] = block
        if " cond_or_uncond " in transformer_options:
            extra_options[" cond_or_uncond "] = transformer_options[" cond_or_uncond "]
        if " patches " in transformer_options:
            transformer_patches = transformer_options[" patches "]
        else:
            transformer_patches = {}

        extra_options[" n_heads "] = self.n_heads
        extra_options[" dim_head "] = self.d_head

        if " patches_replace " in transformer_options:
            transformer_patches_replace = transformer_options[" patches_replace "]
        else:
            transformer_patches_replace = {}

        # ---- first attention (self-attention unless disable_self_attn) ----
        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context
        else:
            context_attn1 = None
        value_attn1 = None

        if " attn1_patch " in transformer_patches:
            patch = transformer_patches[" attn1_patch "]
            # Patches get explicit context/value tensors, so materialize the self-attention default.
            if context_attn1 is None:
                context_attn1 = n
            value_attn1 = context_attn1
            for p in patch:
                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

        # Replacement patches are looked up first by (block[0], block[1], block_index),
        # then by the plain block key.
        if block is not None:
            transformer_block = (block[0], block[1], block_index)
        else:
            transformer_block = None
        attn1_replace_patch = transformer_patches_replace.get(" attn1 ", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        if block_attn1 in attn1_replace_patch:
            if context_attn1 is None:
                context_attn1 = n
                value_attn1 = n
            # The replacement receives the projected q/k/v and its output still
            # goes through attn1's output projection.
            n = self.attn1.to_q(n)
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

        if " attn1_output_patch " in transformer_patches:
            patch = transformer_patches[" attn1_output_patch "]
            for p in patch:
                n = p(n, extra_options)

        # Residual add, in place on x.
        x += n
        if " middle_patch " in transformer_patches:
            patch = transformer_patches[" middle_patch "]
            for p in patch:
                x = p(x, extra_options)

        # ---- second attention (against the provided context) ----
        n = self.norm2(x)

        context_attn2 = context
        value_attn2 = None
        if " attn2_patch " in transformer_patches:
            patch = transformer_patches[" attn2_patch "]
            value_attn2 = context_attn2
            for p in patch:
                n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

        attn2_replace_patch = transformer_patches_replace.get(" attn2 ", {})
        block_attn2 = transformer_block
        if block_attn2 not in attn2_replace_patch:
            block_attn2 = block

        if block_attn2 in attn2_replace_patch:
            if value_attn2 is None:
                value_attn2 = context_attn2
            n = self.attn2.to_q(n)
            context_attn2 = self.attn2.to_k(context_attn2)
            value_attn2 = self.attn2.to_v(value_attn2)
            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
            n = self.attn2.to_out(n)
        else:
            n = self.attn2(n, context=context_attn2, value=value_attn2)

        if " attn2_output_patch " in transformer_patches:
            patch = transformer_patches[" attn2_output_patch "]
            for p in patch:
                n = p(n, extra_options)

        # Residual add, then the feed-forward with its own residual.
        x += n
        x = self.ff(self.norm3(x)) + x
        return x
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image.
    NEW: use_linear for more efficiency instead of the 1x1 convs.

    ``context_dim`` may be a list with one entry per block, a single value, or
    None; non-list values are repeated for every one of the ``depth`` blocks.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        # Broadcast a scalar (or None) to one entry per block. None used to be left
        # as-is and then crashed on context_dim[d] below.
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels, dtype=dtype, device=device)
        if not use_linear:
            self.proj_in = operations.Conv2d(in_channels,
                                             inner_dim,
                                             kernel_size=1,
                                             stride=1,
                                             padding=0, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
             for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = operations.Conv2d(inner_dim, in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0, dtype=dtype, device=device)
        else:
            # Maps inner_dim back to in_channels, matching the conv branch. The
            # arguments were previously swapped, which only worked when
            # inner_dim == in_channels (where the layer shape is unchanged).
            self.proj_out = operations.Linear(inner_dim, in_channels, dtype=dtype, device=device)
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options=None):
        """Apply the transformer blocks to a (b, c, h, w) input and add the residual.

        ``context`` may be a list with one entry per block or a single value that
        is shared by all blocks. The current block index is written into
        ``transformer_options`` before each block runs; a dict passed by the
        caller is updated in place, as before.
        """
        # A fresh dict per call: the old ``{}`` default was shared between calls and
        # mutated by the loop below.
        if transformer_options is None:
            transformer_options = {}
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # The conv projection runs on the image layout, the linear one on the token layout.
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, ' b c h w -> b (h w) c ').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options[" block_index "] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, ' b (h w) c -> b c h w ', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in