import comfy.utils
import folder_paths
import torch
import logging

def load_hypernetwork_patch(path, strength):
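    """Build an attention k/v patch from an A1111-style hypernetwork file.

    Layout inferred from the loading code below: numeric keys are attention
    widths (e.g. 320, 640, 768, 1024), each holding a pair of MLP state
    dicts (index 0 patches k, index 1 patches v); string keys are header
    fields describing activation, normalization and dropout.
    """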
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)
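
    # Note: "swish" is mapped to Hardswish, a cheap piecewise approximation
    # of SiLU/Swish rather than SiLU itself.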
    valid_activation = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "softsign": torch.nn.Softsign,
        "mish": torch.nn.Mish,
    }

    if activation_func not in valid_activation:
        logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
        return None
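
    # Each numeric key is an attention width holding two sub-networks (one
    # for k, one for v); rebuild each as a torch.nn.Sequential below.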
    out = {}

    for d in sd:
        try:
            dim = int(d)
        except (ValueError, TypeError):
            # Skip non-numeric header keys like 'activation_func'.
            continue

        output = []
        for index in [0, 1]:
            attn_weights = sd[dim][index]
            keys = attn_weights.keys()

            # Collect layer names by stripping the ".weight" suffix.
            linears = filter(lambda a: a.endswith(".weight"), keys)
            linears = list(map(lambda a: a[:-len(".weight")], linears))
            layers = []

            i = 0
            while i < len(linears):
                lin_name = linears[i]
                last_layer = (i == (len(linears) - 1))
                penultimate_layer = (i == (len(linears) - 2))

                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers.append(layer)
                if activation_func != "linear":
                    # The final Linear only gets an activation when the
                    # checkpoint requests it via activate_output.
                    if (not last_layer) or (activate_output):
                        layers.append(valid_activation[activation_func]())
                if is_layer_norm:
                    # LayerNorm weights are stored as the next entry in the
                    # same name list, so consume it here.
                    i += 1
                    ln_name = linears[i]
                    ln_weight = attn_weights['{}.weight'.format(ln_name)]
                    ln_bias = attn_weights['{}.bias'.format(ln_name)]
                    ln = torch.nn.LayerNorm(ln_weight.shape[0])
                    ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
                    layers.append(ln)
                if use_dropout:
                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                        layers.append(torch.nn.Dropout(p=0.3))
                i += 1

            output.append(torch.nn.Sequential(*layers))
        out[dim] = torch.nn.ModuleList(output)
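
    # The patch runs inside the model's attention blocks and adds a learned
    # residual to k and v, scaled by strength:
    #   k <- k + strength * MLP_k(k)
    #   v <- v + strength * MLP_v(v)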
    class hypernetwork_patch:
        def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, q, k, v, extra_options):
            # Only attention widths present in the checkpoint are patched.
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            for d in self.hypernet.keys():
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)
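
# Standard ComfyUI node: INPUT_TYPES declares the UI inputs, FUNCTION names
# the method to invoke, and RETURN_TYPES describes the outputs.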
class HypernetworkLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"),),
                             "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_hypernetwork"

    CATEGORY = "loaders"

    def load_hypernetwork(self, model, hypernetwork_name, strength):
        hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
        # Clone so the patch does not mutate the input model in place.
        model_hypernetwork = model.clone()
        patch = load_hypernetwork_patch(hypernetwork_path, strength)
        if patch is not None:
            # Patch both self-attention (attn1) and cross-attention (attn2).
            model_hypernetwork.set_model_attn1_patch(patch)
            model_hypernetwork.set_model_attn2_patch(patch)
        return (model_hypernetwork,)

NODE_CLASS_MAPPINGS = {
    "HypernetworkLoader": HypernetworkLoader
}
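
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the node): build a tiny fake
    # hypernetwork checkpoint, load it through load_hypernetwork_patch, and
    # apply the patch to random q/k/v tensors. The checkpoint layout here is
    # an assumption matching what the loader above expects, and it assumes
    # comfy.utils.load_torch_file accepts a plain torch.save'd .pt file.
    import os
    import tempfile

    dim = 320
    mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 2), torch.nn.Linear(dim * 2, dim))
    sd = {
        'activation_func': 'relu',
        dim: [mlp.state_dict(), mlp.state_dict()],  # index 0 -> k, index 1 -> v
    }
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "test_hypernetwork.pt")
        torch.save(sd, path)
        patch = load_hypernetwork_patch(path, strength=1.0)
        q = k = v = torch.randn(1, 77, dim)
        q, k, v = patch(q, k, v, {})
        print(k.shape)  # expected: torch.Size([1, 77, 320])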