2024-08-09 07:21:10 +00:00
"""
This file is part of ComfyUI .
Copyright ( C ) 2024 Comfy
This program is free software : you can redistribute it and / or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation , either version 3 of the License , or
( at your option ) any later version .
This program is distributed in the hope that it will be useful ,
but WITHOUT ANY WARRANTY ; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE . See the
GNU General Public License for more details .
You should have received a copy of the GNU General Public License
along with this program . If not , see < https : / / www . gnu . org / licenses / > .
"""
2024-08-27 17:55:37 +00:00
from __future__ import annotations
2023-08-25 21:11:51 +00:00
import comfy . utils
2024-08-22 21:20:39 +00:00
import comfy . model_management
import comfy . model_base
2024-03-10 15:37:08 +00:00
import logging
2024-08-22 21:12:00 +00:00
import torch
2023-08-25 21:11:51 +00:00
# CLIP text-encoder submodule paths (dotted, as they appear in state-dict keys)
# mapped to the underscore-joined spelling that is formatted into LoRA key
# names. Every value is simply the key with "." replaced by "_".
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}
def load_lora(lora, to_load):
    """
    Turn a LoRA state dict into a dictionary of weight patches.

    For every key prefix in ``to_load`` the function looks for the tensor
    names of each supported format (lora/locon, loha, lokr, glora, norm
    diffs, plain diffs and ``set_weight``) and records a patch for the
    corresponding model key. When several formats are present for the same
    prefix, the one checked last overwrites the earlier entry.

    Args:
        lora: Mapping of tensor name -> value. It is accessed with ``in``,
            ``[]``, ``.get`` and ``.keys()``.
        to_load: Mapping of LoRA key prefix -> model weight key patched by
            that prefix.

    Returns:
        dict: model key -> ``(patch_type, payload_tuple)``.

    Every name in ``lora`` that was not consumed is reported with a
    ``logging.warning``.
    """
    # (up, down, mid) name templates of the supported LoRA layouts. They are
    # tried in this order and the first layout whose "up" tensor exists wins.
    lora_layouts = (
        ("{}.lora_up.weight", "{}.lora_down.weight", "{}.lora_mid.weight"),  # regular
        ("{}_lora.up.weight", "{}_lora.down.weight", None),  # diffusers
        ("{}.lora_B.weight", "{}.lora_A.weight", None),  # diffusers (lora_A / lora_B)
        ("{}.lora.up.weight", "{}.lora.down.weight", None),  # diffusers (lora.up / lora.down)
        ("{}.lora_B", "{}.lora_A", None),  # mochi
        ("{}.lora_linear_layer.up.weight", "{}.lora_linear_layer.down.weight", None),  # transformers
    )

    patch_dict = {}
    loaded_keys = set()

    def claim(name):
        """Return ``lora[name]`` and mark it as consumed, or None when absent."""
        if name in lora:
            loaded_keys.add(name)
            return lora[name]
        return None

    for x in to_load:
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora:
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        dora_scale = claim("{}.dora_scale".format(x))

        reshape_name = "{}.reshape_weight".format(x)
        reshape = None
        if reshape_name in lora:
            try:
                reshape = lora[reshape_name].tolist()
            except Exception:
                # Best effort: an unparsable value is ignored. The key stays
                # unclaimed, so the final loop reports it as not loaded.
                # Only Exception is caught so that KeyboardInterrupt and
                # SystemExit are no longer swallowed here.
                pass
            else:
                loaded_keys.add(reshape_name)

        ######## lora / locon
        for up_format, down_format, mid_format in lora_layouts:
            A_name = up_format.format(x)
            if A_name not in lora:
                continue
            B_name = down_format.format(x)
            mid = None
            if mid_format is not None:
                mid = claim(mid_format.format(x))
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            break

        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora:
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora:
                # t1 and t2 are read together; t2 is required once t1 exists.
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)

        ######## lokr
        lokr_w1 = claim("{}.lokr_w1".format(x))
        lokr_w2 = claim("{}.lokr_w2".format(x))
        lokr_w1_a = claim("{}.lokr_w1_a".format(x))
        lokr_w1_b = claim("{}.lokr_w1_b".format(x))
        lokr_w2_a = claim("{}.lokr_w2_a".format(x))
        lokr_w2_b = claim("{}.lokr_w2_b".format(x))
        lokr_t2 = claim("{}.lokr_t2".format(x))

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

        ######## glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)

        ######## norm diffs
        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                # The bias key is derived by swapping the trailing ".weight"
                # of the target key for ".bias".
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        ######## plain diffs
        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

        ######## full weight replacement
        set_weight_name = "{}.set_weight".format(x)
        set_weight = lora.get(set_weight_name, None)
        if set_weight is not None:
            patch_dict[to_load[x]] = ("set", (set_weight,))
            loaded_keys.add(set_weight_name)

    for x in lora.keys():
        if x not in loaded_keys:
            logging.warning("lora key not loaded: {}".format(x))

    return patch_dict
def model_lora_keys_clip(model, key_map=None):
    """
    Map text-encoder LoRA key names to keys of ``model.state_dict()``.

    Args:
        model: Object exposing ``state_dict()``; only the keys are used.
        key_map: Optional dict to extend. When given it is updated in place
            and returned; when omitted a fresh dict is created for this call
            (a shared mutable default would leak entries between calls).

    Returns:
        dict: LoRA key name -> state-dict key.
    """
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()
    for k in sdk:
        if k.endswith(".weight"):
            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k  # generic lora format without any weird key names

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    clip_g_present = False
    for b in range(32):  # TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # SDXL base
                key_map[lora_key] = k
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                clip_g_present = True
                if clip_l_present:
                    # clip_l already took the "te1"/"text_encoder" names, so
                    # clip_g is registered as the second text encoder.
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # SDXL base
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # TODO: test if this is correct for SDXL-Refiner
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                    key_map[lora_key] = k
                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # cascade lora: TODO put lora key prefix in the model config
                    key_map[lora_key] = k

    for k in sdk:
        if k.endswith(".weight"):
            if k.startswith("t5xxl.transformer."):  # OneTrainer SD3 and Flux lora
                l_key = k[len("t5xxl.transformer."):-len(".weight")]
                # The t5 encoder index is 1 plus the number of CLIP encoders
                # that were found above.
                t5_index = 1
                if clip_g_present:
                    t5_index += 1
                if clip_l_present:
                    t5_index += 1
                    if t5_index == 2:
                        # Only clip_l is present: register index 2 as well
                        # as index 3 (added below).
                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k  # OneTrainer Flux
                        t5_index += 1

                key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
            elif k.startswith("hydit_clip.transformer.bert."):  # HunyuanDiT Lora
                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
                key_map[lora_key] = k

    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k  # cascade lora?
        # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
        key_map["lora_te2_text_projection"] = k  # OneTrainer SD3 lora

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k  # OneTrainer SD3 lora, not necessary but omits warning

    return key_map
def model_lora_keys_unet(model, key_map=None):
    """
    Map diffusion-model LoRA key names to keys of ``model.state_dict()``.

    Args:
        model: Object exposing ``state_dict()`` and
            ``model_config.unet_config``. Its type (checked with
            ``isinstance`` against classes from ``comfy.model_base``)
            selects additional architecture-specific key formats.
        key_map: Optional dict to extend. When given it is updated in place
            and returned; when omitted a fresh dict is created for this call
            (a shared mutable default would leak entries between calls).

    Returns:
        dict: LoRA key name -> state-dict key.
    """
    if key_map is None:
        key_map = {}

    sd = model.state_dict()
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model."):
            if k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                key_map["lora_unet_{}".format(key_lora)] = k
                key_map["lora_prior_unet_{}".format(key_lora)] = k  # cascade lora: TODO put lora key prefix in the model config
                key_map["{}".format(k[:-len(".weight")])] = k  # generic lora format without any weird key names
            else:
                key_map["{}".format(k)] = k  # generic lora format for not .weight without any weird key names

    diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key
            key_map["lycoris_{}".format(key_lora)] = unet_key  # simpletuner lycoris format

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    # Drop the trailing ".0" so the key ends in ".to_out".
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    if isinstance(model, comfy.model_base.SD3):  # Diffusers lora SD3
        diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])  # regular diffusers sd3 lora format
                key_map[key_lora] = to

                key_lora = "base_model.model.{}".format(k[:-len(".weight")])  # format for flash-sd3 lora and others?
                key_map[key_lora] = to

                key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))  # OneTrainer lora
                key_map[key_lora] = to

                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))  # simpletuner lycoris format
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.AuraFlow):  # Diffusers lora AuraFlow
        diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])  # simpletrainer and probably regular diffusers lora format
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.HunyuanDiT):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k  # official hunyuan lora format

    if isinstance(model, comfy.model_base.Flux):  # Diffusers lora Flux
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_map["transformer.{}".format(k[:-len(".weight")])] = to  # simpletrainer and probably regular diffusers flux lora format
                key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  # simpletrainer lycoris
                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  # onetrainer

    if isinstance(model, comfy.model_base.GenmoMochi):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):  # Official Mochi lora format
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k

    return key_map
2024-08-22 21:12:00 +00:00
2024-10-08 23:02:41 +00:00
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
    """
    Merge ``lora_diff`` into ``weight`` with a norm rescale driven by ``dora_scale``.

    ``lora_diff`` is scaled by ``alpha`` in place, passed through
    ``function`` and added to ``weight``. The sum is then multiplied by
    ``dora_scale / norm``, where the norm is taken over every axis except
    axis 1 (one value per index along axis 1). ``weight`` is updated in
    place and returned: overwritten with the rescaled sum when ``strength``
    is 1.0, otherwise moved towards it by ``strength`` times the difference.
    """
    scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)

    lora_diff *= alpha
    merged = weight + function(lora_diff).type(weight.dtype)

    # Bring axis 1 to the front, flatten everything else and take one norm
    # per row; then restore a shape that broadcasts against ``merged``.
    axis1_size = merged.shape[1]
    singleton_axes = [1] * (merged.dim() - 1)
    flattened = merged.transpose(0, 1).reshape(axis1_size, -1)
    norm = flattened.norm(dim=1, keepdim=True)
    norm = norm.reshape(axis1_size, *singleton_axes).transpose(0, 1)

    merged *= (scale / norm).type(weight.dtype)

    if strength == 1.0:
        weight[:] = merged
        return weight

    # Partial strength: apply only a fraction of the change.
    merged -= weight
    weight += strength * (merged)
    return weight
2024-08-27 17:55:37 +00:00
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape.
        The original values occupy the leading corner (index 0 up to the
        original size along every dimension).

    Raises:
        ValueError: If ``new_shape`` has a different number of dimensions
            than ``tensor``, or is smaller than ``tensor`` in any dimension
            (the tensor is never truncated).
    """
    # The rank check must come first: comparing sizes index by index would
    # otherwise raise IndexError when new_shape has more dimensions.
    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    if any(target < current for target, current in zip(new_shape, tensor.shape)):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Copy the original tensor into the leading corner of the new tensor
    corner = tuple(slice(0, dim) for dim in tensor.shape)
    padded_tensor[corner] = tensor

    return padded_tensor
2024-08-22 21:12:00 +00:00
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
    """
    Apply a sequence of patches to ``weight`` and return the result.

    Args:
        patches: Iterable of indexable entries read as
            ``(strength, value, strength_model, offset, function)``.
            ``value`` is either a 1-tuple (treated as a "diff"), a
            ``(patch_type, payload)`` pair, or a list whose first element is
            a ``(tensor, converter)`` pair followed by further patches.
            ``offset``, when not None, is passed to ``Tensor.narrow`` so the
            patch only touches that slice. ``function``, when not None, is
            applied to the computed delta before it is added.
        weight: Tensor to patch. Most operations update it in place; the
            padding paths replace it with a new, larger tensor.
        key: Name of the weight, used only in log messages.
        intermediate_dtype: dtype used for the intermediate patch math.

    Returns:
        The patched weight.
    """

    def passthrough(a):
        return a

    def to_intermediate(tensor):
        # Bring a patch tensor onto the weight's device in the compute dtype.
        return comfy.model_management.cast_to_device(tensor, weight.device, intermediate_dtype)

    def merge(target, lora_diff, alpha, strength, dora_scale, function):
        # Shared tail of the lora/lokr/loha/glora branches.
        if dora_scale is not None:
            return weight_decompose(dora_scale, target, lora_diff, alpha, strength, intermediate_dtype, function)
        target += function(((strength * alpha) * lora_diff).type(target.dtype))
        return target

    for patch in patches:
        strength = patch[0]
        v = patch[1]
        strength_model = patch[2]
        offset = patch[3]
        function = patch[4] if patch[4] is not None else passthrough

        # With an offset the patch works on a narrowed view; the full tensor
        # is put back at the end of the iteration.
        old_weight = None
        if offset is not None:
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            weight *= strength_model

        if isinstance(v, list):
            # v[0] is (tensor, converter); the remaining entries are patches
            # applied on top of the converted copy. The result is then
            # handled as a plain diff.
            base = comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True)
            v = (calculate_weight(v[1:], v[0][1](base, inplace=True), key, intermediate_dtype=intermediate_dtype),)

        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]

        if patch_type == "diff":
            diff: torch.Tensor = v[0]
            # An extra flag to pad the weight if the diff's shape is larger than the weight
            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
            if do_pad_weight and diff.shape != weight.shape:
                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                weight = pad_tensor_to_shape(weight, diff.shape)

            if strength != 0.0:
                if diff.shape == weight.shape:
                    weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
                else:
                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))

        elif patch_type == "set":
            weight.copy_(v[0])

        elif patch_type == "lora":  # lora/locon
            up = to_intermediate(v[0])
            down = to_intermediate(v[1])
            dora_scale = v[4]
            reshape = v[5]

            if reshape is not None:
                weight = pad_tensor_to_shape(weight, reshape)

            alpha = v[2] / down.shape[0] if v[2] is not None else 1.0

            if v[3] is not None:
                # locon mid weights, hopefully the math is fine because I didn't properly test it
                mid = to_intermediate(v[3])
                final_shape = [down.shape[1], down.shape[0], mid.shape[2], mid.shape[3]]
                down_flat = down.transpose(0, 1).flatten(start_dim=1)
                mid_flat = mid.transpose(0, 1).flatten(start_dim=1)
                down = torch.mm(down_flat, mid_flat).reshape(final_shape).transpose(0, 1)

            try:
                lora_diff = torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1)).reshape(weight.shape)
                weight = merge(weight, lora_diff, alpha, strength, dora_scale, function)
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))

        elif patch_type == "lokr":
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is not None:
                w1 = to_intermediate(w1)
            else:
                dim = w1_b.shape[0]
                w1 = torch.mm(to_intermediate(w1_a), to_intermediate(w1_b))

            if w2 is not None:
                w2 = to_intermediate(w2)
            else:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(to_intermediate(w2_a), to_intermediate(w2_b))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      to_intermediate(t2),
                                      to_intermediate(w2_b),
                                      to_intermediate(w2_a))

            if len(w2.shape) == 4:
                w1 = w1.unsqueeze(2).unsqueeze(2)

            # alpha only applies when a factorised pair supplied a rank.
            alpha = v[2] / dim if (v[2] is not None and dim is not None) else 1.0

            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                weight = merge(weight, lora_diff, alpha, strength, dora_scale, function)
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))

        elif patch_type == "loha":
            w1a = v[0]
            w1b = v[1]
            alpha = v[2] / w1b.shape[0] if v[2] is not None else 1.0
            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]

            if v[5] is None:
                m1 = torch.mm(to_intermediate(w1a), to_intermediate(w1b))
                m2 = torch.mm(to_intermediate(w2a), to_intermediate(w2b))
            else:  # cp decomposition
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  to_intermediate(t1),
                                  to_intermediate(w1b),
                                  to_intermediate(w1a))
                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  to_intermediate(t2),
                                  to_intermediate(w2b),
                                  to_intermediate(w2a))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                weight = merge(weight, lora_diff, alpha, strength, dora_scale, function)
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))

        elif patch_type == "glora":
            dora_scale = v[5]

            # The two shape tests tell the old lycoris glora layout from the
            # newer one; when both match, the old layout is kept only for a
            # square weight whose size equals v[1].shape[0].
            old_glora = False
            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
                rank = v[0].shape[0]
                old_glora = True

            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
                keep_old = old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]
                if not keep_old:
                    old_glora = False
                    rank = v[1].shape[0]

            a1 = to_intermediate(v[0].flatten(start_dim=1))
            a2 = to_intermediate(v[1].flatten(start_dim=1))
            b1 = to_intermediate(v[2].flatten(start_dim=1))
            b2 = to_intermediate(v[3].flatten(start_dim=1))

            alpha = v[4] / rank if v[4] is not None else 1.0

            try:
                if old_glora:
                    # old lycoris glora
                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
                else:
                    if weight.dim() > 2:
                        projected = torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1)
                        lora_diff = torch.einsum("o i ..., i j -> o j ...", projected, a2).reshape(weight.shape)
                    else:
                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)

                weight = merge(weight, lora_diff, alpha, strength, dora_scale, function)
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))

        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            weight = old_weight

    return weight