#AuraFlow MMDiT
#Originally written by the AuraFlow Authors

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.ldm.common_dit

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
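
# modulate() applies AdaLN-style conditioning: `shift` and `scale` are per-sample
# vectors of shape [B, D], broadcast over the sequence dimension of x ([B, T, D])
# via unsqueeze(1).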

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
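
# find_multiple() rounds n up to the next multiple of k,
# e.g. find_multiple(8192, 256) == 8192 and find_multiple(8200, 256) == 8448.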

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim

        n_hidden = int(2 * hidden_dim / 3)
        n_hidden = find_multiple(n_hidden, 256)

        self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
        self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
        self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x
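
# The MLP is a SwiGLU-style gated feed-forward block: silu(c_fc1(x)) gates c_fc2(x).
# The hidden width is roughly 2/3 of 4*dim, rounded up to a multiple of 256
# (e.g. dim=3072 -> hidden_dim=12288 -> n_hidden=8192).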

class MultiHeadLayerNorm(nn.Module):
    def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
        # Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
        super().__init__()
        self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(
            variance + self.variance_epsilon
        )
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)
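
# MultiHeadLayerNorm normalizes over the last dimension in float32 (no bias) and
# applies a learned scale whose shape matches hidden_size (here (n_heads, head_dim),
# so each head gets its own scale) before casting back to the input dtype.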

class SingleAttention(nn.Module):
    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

    #@torch.compile()
    def forward(self, c):
        bsz, seqlen1, _ = c.shape

        q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
        q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
        k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
        v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
        q, k = self.q_norm1(q), self.k_norm1(k)

        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
        c = self.w1o(output)
        return c
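
# SingleAttention is plain self-attention with per-head QK normalization. Tensors
# are arranged as (B, n_heads, seq, head_dim) before calling ComfyUI's
# optimized_attention with skip_reshape=True, which returns (B, seq, dim).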

class DoubleAttention(nn.Module):
    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # this is for x
        self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

        self.q_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

        self.k_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

    #@torch.compile()
    def forward(self, c, x):
        bsz, seqlen1, _ = c.shape
        bsz, seqlen2, _ = x.shape
        seqlen = seqlen1 + seqlen2

        cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
        cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
        ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cq, ck = self.q_norm1(cq), self.k_norm1(ck)

        xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
        xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xq, xk = self.q_norm2(xq), self.k_norm2(xk)

        # concat all
        q, k, v = (
            torch.cat([cq, xq], dim=1),
            torch.cat([ck, xk], dim=1),
            torch.cat([cv, xv], dim=1),
        )

        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)

        c, x = output.split([seqlen1, seqlen2], dim=1)
        c = self.w1o(c)
        x = self.w2o(x)

        return c, x
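
# DoubleAttention runs joint attention over the conditioning tokens (w1* weights)
# and the image tokens (w2* weights): each stream is projected and QK-normalized
# separately, the streams are concatenated along the sequence dimension for a
# single attention call, then split back and projected out per stream.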

class MMDiTBlock(nn.Module):
    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        if not is_last:
            self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
            )
        else:
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
            )

        self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
        self.modX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.is_last = is_last

    #@torch.compile()
    def forward(self, c, x, global_cond, **kwargs):
        cres, xres = c, x

        cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
            self.modC(global_cond).chunk(6, dim=1)
        )

        c = modulate(self.normC1(c), cshift_msa, cscale_msa)

        # xpath
        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
            self.modX(global_cond).chunk(6, dim=1)
        )

        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention
        c, x = self.attn(c, x)

        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
        c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
        c = cres + c

        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
        x = xres + x

        return c, x
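
# MMDiTBlock is the two-stream (conditioning + image) transformer block: modC and
# modX map the shared global conditioning vector to shift/scale/gate parameters
# for each stream's attention and MLP sub-blocks, applied around DoubleAttention
# with residual connections.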

class DiTBlock(nn.Module):
    # like MMDiTBlock, but it only has X
    def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
        super().__init__()

        self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)

        self.modCX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

    #@torch.compile()
    def forward(self, cx, global_cond, **kwargs):
        cxres = cx
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
            global_cond
        ).chunk(6, dim=1)
        cx = modulate(self.norm1(cx), shift_msa, scale_msa)
        cx = self.attn(cx)
        cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
        mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
        cx = gate_mlp.unsqueeze(1) * mlpout
        cx = cxres + cx

        return cx
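
# DiTBlock is the single-stream variant: it applies SingleAttention and one MLP to
# the already-concatenated [conditioning, image] sequence, with the same
# six-parameter modulation scheme driven by the global conditioning vector.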

class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = 1000 * torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    #@torch.compile()
    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
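
# TimestepEmbedder builds a sinusoidal frequency embedding of the timestep and
# projects it through a two-layer MLP. The 1000x factor on the frequencies
# suggests timesteps are expected in roughly [0, 1] (continuous sigmas) rather
# than integer steps up to 1000.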

class MMDiT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=3072,
        n_layers=36,
        n_double_layers=4,
        n_heads=12,
        global_conddim=3072,
        cond_seq_dim=2048,
        max_seq=32 * 32,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)

        self.cond_seq_linear = operations.Linear(
            cond_seq_dim, dim, bias=False, dtype=dtype, device=device
        )  # linear for something like text sequence.
        self.init_x_linear = operations.Linear(
            patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
        )  # init linear for patchified image.

        self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
        self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))

        self.double_layers = nn.ModuleList([])
        self.single_layers = nn.ModuleList([])

        for idx in range(n_double_layers):
            self.double_layers.append(
                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
            )

        for idx in range(n_double_layers, n_layers):
            self.single_layers.append(
                DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
            )

        self.final_linear = operations.Linear(
            dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
        )

        self.modF = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
        )

        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_double_layers = n_double_layers
        self.n_layers = n_layers

        self.h_max = round(max_seq**0.5)
        self.w_max = round(max_seq**0.5)
    @torch.no_grad()
    def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
        # extend pe
        pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]

        pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

        # now we need to extend this to target_dim. for this we will use interpolation.
        # we will use torch.nn.functional.interpolate
        pe_as_2d = F.interpolate(
            pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
        )
        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
        self.h_max, self.w_max = target_dim
        print("PE extended to", target_dim)
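
    # extend_pe() resizes the learned positional-encoding grid in place: the flat
    # parameter is viewed as an init_dim 2D grid, bilinearly interpolated to
    # target_dim, and flattened back, with h_max/w_max updated to match.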
    def pe_selection_index_based_on_dim(self, h, w):
        h_p, w_p = h // self.patch_size, w // self.patch_size
        original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
        original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
        starth = self.h_max // 2 - h_p // 2
        endh = starth + h_p
        startw = self.w_max // 2 - w_p // 2
        endw = startw + w_p
        original_pe_indexes = original_pe_indexes[
            starth:endh, startw:endw
        ]
        return original_pe_indexes.flatten()
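
    # pe_selection_index_based_on_dim() returns the flattened indices of a centered
    # h_p x w_p crop of the positional-encoding grid. In this file it is only
    # referenced from commented-out code in forward(); apply_pos_embeds() is used
    # instead.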
    def unpatchify(self, x, h, w):
        c = self.out_channels
        p = self.patch_size

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
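
    # unpatchify() inverts patchify(): (N, h*w, p*p*c) tokens are reshaped to
    # (N, h, w, p, p, c) and rearranged via einsum into an (N, c, h*p, w*p) image.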
    def patchify(self, x):
        B, C, H, W = x.size()
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        x = x.view(
            B,
            C,
            (H + 1) // self.patch_size,
            self.patch_size,
            (W + 1) // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x
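
    # patchify() pads H and W up to multiples of patch_size and flattens each
    # patch_size x patch_size patch into a token of length C * patch_size**2.
    # With the default patch_size of 2, (H + 1) // patch_size is ceil(H / 2),
    # which matches the padded grid size.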
    def apply_pos_embeds(self, x, h, w):
        h = (h + 1) // self.patch_size
        w = (w + 1) // self.patch_size
        max_dim = max(h, w)

        cur_dim = self.h_max
        pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
        if max_dim > cur_dim:
            pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
            cur_dim = max_dim

        from_h = (cur_dim - h) // 2
        from_w = (cur_dim - w) // 2
        pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
        return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
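
    # apply_pos_embeds() views the learned PE as a square (h_max, h_max) grid,
    # bilinearly upsamples it when the requested token grid is larger, then adds a
    # centered crop matching the current latent size to x.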
    def forward(self, x, timestep, context, transformer_options={}, **kwargs):
        patches_replace = transformer_options.get("patches_replace", {})
        # patchify x, add PE
        b, c, h, w = x.shape

        # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
        # print(pe_indexes, pe_indexes.shape)

        x = self.init_x_linear(self.patchify(x))  # B, T_x, D
        x = self.apply_pos_embeds(x, h, w)
        # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
        # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)

        # process conditions for MMDiT Blocks
        c_seq = context  # B, T_c, D_c
        t = timestep

        c = self.cond_seq_linear(c_seq)  # B, T_c, D
        c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
        global_cond = self.t_embedder(t, x.dtype)  # B, D
        blocks_replace = patches_replace.get("dit", {})
        if len(self.double_layers) > 0:
            for i, layer in enumerate(self.double_layers):
                if ("double_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["txt"], out["img"] = layer(args["txt"],
                                                       args["img"],
                                                       args["vec"])
                        return out

                    out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
                    c = out["txt"]
                    x = out["img"]
                else:
                    c, x = layer(c, x, global_cond, **kwargs)
        if len(self.single_layers) > 0:
            c_len = c.size(1)
            cx = torch.cat([c, x], dim=1)
            for i, layer in enumerate(self.single_layers):
                if ("single_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["img"] = layer(args["img"], args["vec"])
                        return out

                    out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
                    cx = out["img"]
                else:
                    cx = layer(cx, global_cond, **kwargs)
            x = cx[:, c_len:]

        fshift, fscale = self.modF(global_cond).chunk(2, dim=1)

        x = modulate(x, fshift, fscale)
        x = self.final_linear(x)
        x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:, :, :h, :w]
        return x
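

if __name__ == "__main__":
    # Minimal shape-check sketch (not part of the original module). It assumes a
    # working ComfyUI environment and uses comfy.ops.disable_weight_init as the
    # `operations` provider; the tiny hyperparameters below are illustrative only
    # and do not correspond to any released AuraFlow checkpoint.
    model = MMDiT(
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=64,
        n_layers=4,
        n_double_layers=2,
        n_heads=4,
        global_conddim=64,
        cond_seq_dim=32,
        max_seq=16 * 16,
        dtype=torch.float32,
        operations=comfy.ops.disable_weight_init,
    )
    # Parameters created with torch.empty are uninitialized, so fill them before
    # running a forward pass to avoid propagating NaNs.
    for p in model.parameters():
        nn.init.normal_(p, std=0.02)

    x = torch.randn(1, 4, 16, 16)      # latent image, patchified into an 8x8 token grid
    timestep = torch.rand(1)           # continuous timestep in [0, 1]
    context = torch.randn(1, 77, 32)   # stand-in for a text-conditioning sequence
    out = model(x, timestep, context)
    print(out.shape)                   # expected: torch.Size([1, 4, 16, 16])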