2023-01-03 06:53:32 +00:00
import torch
2024-02-16 18:29:04 +00:00
from enum import Enum
2024-03-10 15:37:08 +00:00
import logging
2023-01-03 06:53:32 +00:00
2023-04-15 22:55:17 +00:00
from comfy import model_management
2023-10-17 18:51:51 +00:00
from . ldm . models . autoencoder import AutoencoderKL , AutoencodingEngine
2024-02-16 11:30:39 +00:00
from . ldm . cascade . stage_a import StageA
2024-02-19 09:06:49 +00:00
from . ldm . cascade . stage_c_coder import StageC_coder
2024-06-15 16:14:56 +00:00
from . ldm . audio . autoencoder import AudioOobleckVAE
2024-10-26 10:54:00 +00:00
import comfy . ldm . genmo . vae . model
2024-11-22 13:44:42 +00:00
import comfy . ldm . lightricks . vae . causal_video_autoencoder
2023-03-13 18:49:18 +00:00
import yaml
2023-02-16 15:38:08 +00:00
2023-08-25 21:25:39 +00:00
import comfy . utils
2023-04-02 03:19:15 +00:00
from . import clip_vision
2023-04-19 13:36:19 +00:00
from . import gligen
2023-05-28 06:02:09 +00:00
from . import diffusers_convert
2023-06-22 17:03:50 +00:00
from . import model_detection
2023-02-03 07:06:34 +00:00
2023-06-22 17:03:50 +00:00
from . import sd1_clip
2023-06-25 05:40:38 +00:00
from . import sdxl_clip
2024-07-28 05:19:20 +00:00
import comfy . text_encoders . sd2_clip
2024-07-15 21:36:24 +00:00
import comfy . text_encoders . sd3_clip
import comfy . text_encoders . sa_t5
2024-07-11 20:51:06 +00:00
import comfy . text_encoders . aura_t5
2024-07-25 22:21:08 +00:00
import comfy . text_encoders . hydit
2024-08-01 08:03:59 +00:00
import comfy . text_encoders . flux
2024-08-20 14:42:40 +00:00
import comfy . text_encoders . long_clipl
2024-10-26 10:54:00 +00:00
import comfy . text_encoders . genmo
2024-11-22 13:44:42 +00:00
import comfy . text_encoders . lt
2023-06-09 16:24:24 +00:00
2023-08-28 18:49:18 +00:00
import comfy . model_patcher
2023-08-25 21:11:51 +00:00
import comfy . lora
2024-11-21 13:38:23 +00:00
import comfy . lora_convert
2023-08-25 21:25:39 +00:00
import comfy . t2i_adapter . adapter
2023-11-21 17:54:19 +00:00
import comfy . taesd . taesd
2023-08-25 21:11:51 +00:00
2024-11-21 13:38:23 +00:00
import comfy . ldm . flux . redux
2023-06-30 03:40:02 +00:00
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    """Apply a LoRA state dict to a diffusion model and/or a CLIP text encoder.

    Either ``model`` or ``clip`` may be ``None``; that side is then skipped.
    The LoRA is converted with ``comfy.lora_convert.convert_lora`` and matched
    against the key maps of whichever objects were provided.

    Args:
        model: model patcher (has ``.model`` and ``.clone()``) or ``None``.
        clip: ``CLIP`` instance or ``None``.
        lora: LoRA state dict.
        strength_model: patch strength for the diffusion model.
        strength_clip: patch strength for the text encoder.

    Returns:
        ``(new_model, new_clip)``: patched clones, with ``None`` in place of
        any input that was ``None``.
    """
    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    converted = comfy.lora_convert.convert_lora(lora)
    loaded = comfy.lora.load_lora(converted, key_map)

    patched_model = None
    model_keys = set()
    if model is not None:
        patched_model = model.clone()
        model_keys = set(patched_model.add_patches(loaded, strength_model))

    patched_clip = None
    clip_keys = set()
    if clip is not None:
        patched_clip = clip.clone()
        clip_keys = set(patched_clip.add_patches(loaded, strength_clip))

    # Report every LoRA key that neither target accepted.
    applied = model_keys | clip_keys
    for key in loaded:
        if key not in applied:
            logging.warning("NOT LOADED {}".format(key))

    return (patched_model, patched_clip)
2023-01-03 06:53:32 +00:00
class CLIP:
    """Tokenizer plus text-encoder pair whose encoder is managed by a ModelPatcher.

    ``target`` supplies ``params`` (constructor kwargs for the encoder),
    ``clip`` (the encoder class) and ``tokenizer`` (the tokenizer class).
    """

    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data=None, parameters=0, model_options=None):
        """Instantiate the encoder and tokenizer described by ``target``.

        Args:
            target: object with ``params``, ``clip`` and ``tokenizer`` attributes.
            embedding_directory: forwarded to the tokenizer constructor.
            no_init: when True, return an empty shell (used by ``clone``).
            tokenizer_data: forwarded to the tokenizer; ``None`` means ``{}``.
            parameters: parameter count; multiplied by the dtype size to choose
                the initial device.
            model_options: optional overrides read here ("load_device",
                "offload_device", "dtype", "initial_device") and also passed to
                the encoder constructor; ``None`` means ``{}``.
        """
        if no_init:
            return
        # Fresh dicts per call. The former ``{}`` defaults were one object
        # shared by every CLIP, and model_options is handed to the encoder.
        if tokenizer_data is None:
            tokenizer_data = {}
        if model_options is None:
            model_options = {}

        params = target.params.copy()
        clip = target.clip
        tokenizer = target.tokenizer

        load_device = model_options.get("load_device", model_management.text_encoder_device())
        offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
        dtype = model_options.get("dtype", None)
        if dtype is None:
            dtype = model_management.text_encoder_dtype(load_device)

        params['dtype'] = dtype
        params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
        params['model_options'] = model_options

        self.cond_stage_model = clip(**(params))

        # If the load device cannot cast one of the encoder's dtypes, run on the
        # offload device instead and move the weights there if needed.
        for dt in self.cond_stage_model.dtypes:
            if not model_management.supports_cast(load_device, dt):
                load_device = offload_device
                if params['device'] != offload_device:
                    self.cond_stage_model.to(offload_device)
                    logging.warning("Had to shift TE back.")

        self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
        if params['device'] == load_device:
            model_management.load_models_gpu([self.patcher], force_full_load=True)
        self.layer_idx = None
        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))

    def clone(self):
        """Return a CLIP sharing encoder and tokenizer but with a cloned patcher."""
        n = CLIP(no_init=True)
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        return n

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        """Forward to the patcher's ``add_patches`` and return its result."""
        return self.patcher.add_patches(patches, strength_patch, strength_model)

    def clip_layer(self, layer_idx):
        """Remember which encoder layer ``encode_from_tokens`` should output."""
        self.layer_idx = layer_idx

    def tokenize(self, text, return_word_ids=False):
        """Tokenize ``text`` with per-token weights via the tokenizer."""
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)

    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
        """Encode tokens into conditioning.

        ``return_pooled`` may also be the string ``"unprojected"``, which
        disables pooled projection and (being truthy) returns the pair.

        Returns:
            A dict with ``cond``, ``pooled_output`` and any extra outputs when
            ``return_dict`` is set; otherwise ``(cond, pooled)`` when
            ``return_pooled`` is truthy; otherwise ``cond``.
        """
        self.cond_stage_model.reset_clip_options()

        if self.layer_idx is not None:
            self.cond_stage_model.set_clip_options({"layer": self.layer_idx})

        if return_pooled == "unprojected":
            self.cond_stage_model.set_clip_options({"projected_pooled": False})

        self.load_model()
        o = self.cond_stage_model.encode_token_weights(tokens)
        cond, pooled = o[:2]
        if return_dict:
            out = {"cond": cond, "pooled_output": pooled}
            # A third element, when present, holds extra named outputs.
            if len(o) > 2:
                for k in o[2]:
                    out[k] = o[2][k]
            return out

        if return_pooled:
            return cond, pooled
        return cond

    def encode(self, text):
        """Tokenize ``text`` and return its conditioning tensor."""
        tokens = self.tokenize(text)
        return self.encode_from_tokens(tokens)

    def load_sd(self, sd, full_model=False):
        """Load weights: non-strict ``load_state_dict`` if ``full_model``, else the encoder's ``load_sd``."""
        if full_model:
            return self.cond_stage_model.load_state_dict(sd, strict=False)
        else:
            return self.cond_stage_model.load_sd(sd)

    def get_sd(self):
        """Return the encoder state dict merged with the tokenizer state dict."""
        sd_clip = self.cond_stage_model.state_dict()
        sd_tokenizer = self.tokenizer.state_dict()
        for k in sd_tokenizer:
            sd_clip[k] = sd_tokenizer[k]
        return sd_clip

    def load_model(self):
        """Ask model_management to load the patcher onto its device; return it."""
        model_management.load_model_gpu(self.patcher)
        return self.patcher

    def get_key_patches(self):
        """Return the patcher's key patches."""
        return self.patcher.get_key_patches()
2023-01-03 06:53:32 +00:00
class VAE:
    """Wrapper around a first-stage autoencoder (image, audio or video).

    The architecture is chosen from marker keys of the state dict (or built
    from ``config`` as an AutoencoderKL). Alongside the model the instance
    keeps per-architecture metadata: latent channels and dimensionality,
    up/downscale ratios, pixel pre/post-processing callables, memory
    estimators and the preferred working dtypes.
    """

    def __init__(self, sd=None, device=None, config=None, dtype=None):
        """Build the first-stage model, load ``sd`` into it and set up its patcher.

        Args:
            sd: state dict whose keys select the architecture; diffusers-format
                keys are converted first.
            device: load device; defaults to ``model_management.vae_device()``.
            config: when given, ``config['params']`` is passed to AutoencoderKL
                and key-based detection is skipped.
            dtype: weight dtype; defaults to
                ``model_management.vae_dtype(device, self.working_dtypes)``.

        If no known key is found, a warning is logged, ``first_stage_model`` is
        set to None and initialisation stops before device/patcher setup.
        """
        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
            sd = diffusers_convert.convert_vae_state_dict(sd)

        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
        self.downscale_ratio = 8
        self.upscale_ratio = 8
        self.latent_channels = 4
        self.latent_dim = 2
        self.output_channels = 3
        self.process_input = lambda image: image * 2.0 - 1.0
        self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
        self.working_dtypes = [torch.bfloat16, torch.float32]

        if config is None:
            if "decoder.mid.block_1.mix_factor" in sd:
                encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                decoder_config = encoder_config.copy()
                decoder_config["video_kernel_size"] = [3, 1, 1]
                decoder_config["alpha"] = 0.0
                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
                                                            decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
            elif "taesd_decoder.1.weight" in sd:
                self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
                self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
            elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
                self.first_stage_model = StageA()
                self.downscale_ratio = 4
                self.upscale_ratio = 4
                #TODO
                #self.memory_used_encode
                #self.memory_used_decode
                self.process_input = lambda image: image
                self.process_output = lambda image: image
            elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
                self.first_stage_model = StageC_coder()
                self.downscale_ratio = 32
                self.latent_channels = 16
                new_sd = {}
                for k in sd:
                    new_sd["encoder.{}".format(k)] = sd[k]
                sd = new_sd
            elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
                self.first_stage_model = StageC_coder()
                self.latent_channels = 16
                new_sd = {}
                for k in sd:
                    new_sd["previewer.{}".format(k)] = sd[k]
                sd = new_sd
            elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
                self.first_stage_model = StageC_coder()
                self.downscale_ratio = 32
                self.latent_channels = 16
            elif "decoder.conv_in.weight" in sd:
                #default SD1.x/SD2.x VAE parameters
                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}

                if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
                    ddconfig['ch_mult'] = [1, 2, 4]
                    self.downscale_ratio = 4
                    self.upscale_ratio = 4

                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
                if 'quant_conv.weight' in sd:
                    self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
                else:
                    self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                                encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
                                                                decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
            elif "decoder.layers.1.layers.0.beta" in sd:
                self.first_stage_model = AudioOobleckVAE()
                self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
                self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
                self.latent_channels = 64
                self.output_channels = 2
                self.upscale_ratio = 2048
                self.downscale_ratio = 2048
                self.latent_dim = 1
                self.process_output = lambda audio: audio
                self.process_input = lambda audio: audio
                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
            elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
                if "blocks.2.blocks.3.stack.5.weight" in sd:
                    sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
                if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
                    sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
                self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
                self.latent_channels = 12
                self.latent_dim = 3
                self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                # (temporal mapping, height factor, width factor)
                self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
                self.working_dtypes = [torch.float16, torch.float32]
            elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
                self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
                self.latent_channels = 128
                self.latent_dim = 3
                self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
                self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
                self.working_dtypes = [torch.bfloat16, torch.float32]
            else:
                logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                self.first_stage_model = None
                return
        else:
            self.first_stage_model = AutoencoderKL(**(config['params']))
        self.first_stage_model = self.first_stage_model.eval()

        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
        if len(m) > 0:
            logging.warning("Missing VAE keys {}".format(m))

        if len(u) > 0:
            logging.debug("Leftover VAE keys {}".format(u))

        if device is None:
            device = model_management.vae_device()
        self.device = device
        offload_device = model_management.vae_offload_device()
        if dtype is None:
            dtype = model_management.vae_dtype(self.device, self.working_dtypes)
        self.vae_dtype = dtype
        self.first_stage_model.to(self.vae_dtype)
        self.output_device = model_management.intermediate_device()

        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
        logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

    def vae_encode_crop_pixels(self, pixels):
        """Center-crop axes 1..-2 of ``pixels`` to multiples of ``downscale_ratio``.

        The first and last axes are left untouched; ``encode`` moves the last
        axis to the channel position afterwards.
        """
        dims = pixels.shape[1:-1]
        for d, size in enumerate(dims):
            x = (size // self.downscale_ratio) * self.downscale_ratio
            x_offset = (size % self.downscale_ratio) // 2
            if x != size:
                pixels = pixels.narrow(d + 1, x_offset, x)
        return pixels

    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
        """Tiled 2-D decode: average of three tilings with different tile aspect ratios."""
        steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)

        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
        output = self.process_output(
            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
            / 3.0)
        return output

    def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
        """Tiled decode along a single latent axis (no ``process_output`` applied)."""
        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
        return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)

    def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
        """Tiled decode over three latent axes, followed by ``process_output``."""
        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))

    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        """Tiled 2-D encode: average of three tilings with different tile aspect ratios."""
        steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)

        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
        samples /= 3.0
        return samples

    def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
        """Tiled encode along a single axis."""
        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)

    def decode(self, samples_in):
        """Decode latents in memory-sized batches; fall back to tiling on OOM.

        Returns the decoded output with axis 1 moved to the last position.
        """
        pixel_samples = None
        try:
            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = model_management.get_free_memory(self.device)
            # Guard against a zero estimate (e.g. an empty axis), as encode() does.
            batch_number = int(free_memory / max(1, memory_used))
            batch_number = max(1, batch_number)

            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
                out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
                if pixel_samples is None:
                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                pixel_samples[x:x + batch_number] = out
        except model_management.OOM_EXCEPTION:
            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            dims = samples_in.ndim - 2
            if dims == 1:
                pixel_samples = self.decode_tiled_1d(samples_in)
            elif dims == 2:
                pixel_samples = self.decode_tiled_(samples_in)
            elif dims == 3:
                tile = 256 // self.spacial_compression_decode()
                overlap = tile // 4
                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))

        pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
        return pixel_samples

    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
        """Decode with tiling, dispatching on latent dimensionality.

        Only tiling arguments that are not None are forwarded, so each tiled
        decoder's own defaults apply otherwise.
        """
        memory_used = self.memory_used_decode(samples.shape, self.vae_dtype)  #TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
        dims = samples.ndim - 2
        args = {}
        if tile_x is not None:
            args["tile_x"] = tile_x
        if tile_y is not None:
            args["tile_y"] = tile_y
        if overlap is not None:
            args["overlap"] = overlap

        if dims == 1:
            # decode_tiled_1d has no tile_y parameter; it may not have been given.
            args.pop("tile_y", None)
            output = self.decode_tiled_1d(samples, **args)
        elif dims == 2:
            output = self.decode_tiled_(samples, **args)
        elif dims == 3:
            output = self.decode_tiled_3d(samples, **args)
        return output.movedim(1, -1)

    def encode(self, pixel_samples):
        """Encode channels-last pixels in memory-sized batches; fall back to tiling on OOM."""
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        pixel_samples = pixel_samples.movedim(-1, 1)
        if self.latent_dim == 3:
            # Video: treat the incoming batch axis as time and add a batch of 1.
            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
        try:
            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / max(1, memory_used))
            batch_number = max(1, batch_number)
            samples = None
            for x in range(0, pixel_samples.shape[0], batch_number):
                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
                out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
                if samples is None:
                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                samples[x:x + batch_number] = out

        except model_management.OOM_EXCEPTION:
            logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
            if len(pixel_samples.shape) == 3:
                samples = self.encode_tiled_1d(pixel_samples)
            else:
                samples = self.encode_tiled_(pixel_samples)

        return samples

    def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        """Encode channels-last pixels using 2-D tiling."""
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        model_management.load_model_gpu(self.patcher)
        pixel_samples = pixel_samples.movedim(-1, 1)
        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
        return samples

    def get_sd(self):
        """Return the first-stage model's state dict."""
        return self.first_stage_model.state_dict()

    def spacial_compression_decode(self):
        """Return the spatial decode factor: the last entry of a tuple ratio, or the scalar ratio."""
        try:
            return self.upscale_ratio[-1]
        except TypeError:  # scalar ratios (e.g. int) are not subscriptable
            return self.upscale_ratio
2023-03-05 23:39:25 +00:00
class StyleModel:
    """Thin holder around a style / image-conditioning network."""

    def __init__(self, model, device="cpu"):
        # ``device`` is accepted for call compatibility but not used here.
        self.model = model

    def get_cond(self, input):
        """Run the wrapped network on ``input.last_hidden_state`` and return its output."""
        hidden = input.last_hidden_state
        return self.model(hidden)
def load_style_model(ckpt_path):
    """Load a style-model checkpoint and wrap it in ``StyleModel``.

    The architecture is picked from marker keys: ``style_embedding`` selects
    the T2I StyleAdapter, ``redux_down.weight`` selects the Flux Redux encoder.

    Raises:
        Exception: if neither marker key is present.
    """
    state = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if "style_embedding" in state.keys():
        network = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
    elif "redux_down.weight" in state.keys():
        network = comfy.ldm.flux.redux.ReduxImageEncoder()
    else:
        raise Exception("invalid style model {}".format(ckpt_path))
    network.load_state_dict(state)
    return StyleModel(network)
2024-02-16 18:29:04 +00:00
class CLIPType(Enum):
    """Model family a text encoder is being loaded for.

    Used by the text-encoder loaders to pick the matching encoder/tokenizer
    classes. Values are explicit integers.
    """
    STABLE_DIFFUSION = 1  # default choice of load_clip
    STABLE_CASCADE = 2
    SD3 = 3
    STABLE_AUDIO = 4
    HUNYUAN_DIT = 5
    FLUX = 6
    MOCHI = 7
    LTXV = 8
2023-03-05 23:39:25 +00:00
2024-08-17 14:15:13 +00:00
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options=None):
    """Load text-encoder checkpoints from disk and build a CLIP from them.

    Args:
        ckpt_paths: iterable of checkpoint file paths; each is loaded with
            ``comfy.utils.load_torch_file(..., safe_load=True)``.
        embedding_directory: forwarded to ``load_text_encoder_state_dicts``.
        clip_type: ``CLIPType`` selecting the model family.
        model_options: forwarded to ``load_text_encoder_state_dicts``;
            ``None`` means a fresh empty dict.

    Returns:
        Whatever ``load_text_encoder_state_dicts`` returns for these state dicts.
    """
    # A fresh dict per call: the former ``{}`` default was shared between calls
    # and ends up stored on the CLIP / passed to the encoder constructor.
    if model_options is None:
        model_options = {}
    clip_data = [comfy.utils.load_torch_file(p, safe_load=True) for p in ckpt_paths]
    return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
2023-06-25 05:40:38 +00:00
2024-10-01 11:08:41 +00:00
class TEModel(Enum):
    """Text-encoder architectures recognised by ``detect_te_model``."""
    # CLIP text towers
    CLIP_L = 1
    CLIP_H = 2
    CLIP_G = 3
    # T5 encoders
    T5_XXL = 4
    T5_XL = 5
    T5_BASE = 6
def detect_te_model(sd):
    """Guess the text-encoder architecture of state dict ``sd``.

    CLIP probes are tried from the highest layer index down (30 -> CLIP_G,
    22 -> CLIP_H, 0 -> CLIP_L). For T5, the last dimension of block 23's
    ``wi_1`` weight distinguishes XXL (4096) from XL (2048). Any other T5 with
    a block-0 attention key is reported as T5_BASE.

    Returns:
        A ``TEModel`` member, or ``None`` if nothing matched.
    """
    clip_probes = (
        ("text_model.encoder.layers.30.mlp.fc1.weight", TEModel.CLIP_G),
        ("text_model.encoder.layers.22.mlp.fc1.weight", TEModel.CLIP_H),
        ("text_model.encoder.layers.0.mlp.fc1.weight", TEModel.CLIP_L),
    )
    for probe_key, kind in clip_probes:
        if probe_key in sd:
            return kind

    t5_key = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
    if t5_key in sd:
        width = sd[t5_key].shape[-1]
        if width == 4096:
            return TEModel.T5_XXL
        if width == 2048:
            return TEModel.T5_XL

    if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
        return TEModel.T5_BASE
    return None
2024-10-10 19:06:15 +00:00
2024-10-21 02:27:00 +00:00
def t5xxl_detect(clip_data):
    """Return T5-XXL detection kwargs for the first state dict that has a T5 block-23 weight.

    Delegates to ``comfy.text_encoders.sd3_clip.t5_xxl_detect`` on that state
    dict. Returns ``{}`` when no state dict in ``clip_data`` matches.
    """
    weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
    match = next((sd for sd in clip_data if weight_name in sd), None)
    if match is None:
        return {}
    return comfy.text_encoders.sd3_clip.t5_xxl_detect(match)
2024-10-10 19:06:15 +00:00
2024-08-19 21:36:35 +00:00
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
    """Build a CLIP text-encoder object from one to three state dicts.

    The encoder/tokenizer pair is chosen from the number of state dicts and,
    for a single dict, from the architecture reported by ``detect_te_model``.
    ``clip_type`` disambiguates models that share an architecture. The weights
    of every state dict are then loaded into the ``CLIP`` object; missing keys
    are logged as warnings and unexpected keys at debug level.

    Args:
        state_dicts: sequence of text-encoder state dicts. The sequence itself
            is not modified. Dicts holding a legacy ``text_projection`` key
            gain a ``text_projection.weight`` entry in place.
        embedding_directory: forwarded to ``CLIP``.
        clip_type: ``CLIPType`` member used to pick between encoders that share
            an architecture.
        model_options: forwarded to ``CLIP`` after passing through
            ``model_options_long_clip`` for each state dict.

    Returns:
        The constructed ``CLIP`` object.

    NOTE(review): with zero or more than three state dicts no encoder class is
    assigned to the target; CLIP construction presumably fails in that case.
    """
    # Work on a copy so the legacy-format conversion below does not replace
    # entries in the caller's list (and tuples/other sequences are accepted).
    clip_data = list(state_dicts)

    class EmptyClass:
        pass

    for i, sd in enumerate(clip_data):
        if "transformer.resblocks.0.ln_1.weight" in sd:
            # "transformer.resblocks.*" key layout: rename via the utils converter.
            clip_data[i] = comfy.utils.clip_text_transformers_convert(sd, "", "")
        elif "text_projection" in sd:
            sd["text_projection.weight"] = sd["text_projection"].transpose(0, 1)  # old models saved with the CLIPSave node

    clip_target = EmptyClass()
    clip_target.params = {}

    if len(clip_data) == 1:
        te_model = detect_te_model(clip_data[0])
        if te_model == TEModel.CLIP_G:
            if clip_type == CLIPType.STABLE_CASCADE:
                clip_target.clip = sdxl_clip.StableCascadeClipModel
                clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
            elif clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
            else:
                clip_target.clip = sdxl_clip.SDXLRefinerClipModel
                clip_target.tokenizer = sdxl_clip.SDXLTokenizer
        elif te_model == TEModel.CLIP_H:
            clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
            clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
        elif te_model == TEModel.T5_XXL:
            if clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
            elif clip_type == CLIPType.LTXV:
                clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
            else:  # CLIPType.MOCHI, and any other clip_type, falls through to the Mochi encoder
                clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
        elif te_model == TEModel.T5_XL:
            clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
            clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
        elif te_model == TEModel.T5_BASE:
            clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
            clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
        else:
            if clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
            else:
                clip_target.clip = sd1_clip.SD1ClipModel
                clip_target.tokenizer = sd1_clip.SD1Tokenizer
    elif len(clip_data) == 2:
        if clip_type == CLIPType.SD3:
            te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
        elif clip_type == CLIPType.HUNYUAN_DIT:
            clip_target.clip = comfy.text_encoders.hydit.HyditModel
            clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
        elif clip_type == CLIPType.FLUX:
            clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
        else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
    elif len(clip_data) == 3:
        clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
        clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

    parameters = 0
    tokenizer_data = {}
    for c in clip_data:
        parameters += comfy.utils.calculate_parameters(c)
        tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
    for c in clip_data:
        m, u = clip.load_sd(c)
        if len(m) > 0:
            logging.warning("clip missing: {}".format(m))

        if len(u) > 0:
            logging.debug("clip unexpected: {}".format(u))
    return clip
2023-01-03 06:53:32 +00:00
2023-04-19 13:36:19 +00:00
def load_gligen(ckpt_path):
    """Load a GLIGEN checkpoint and wrap the model in a ModelPatcher.

    ``.half()`` is applied to the model when
    ``model_management.should_use_fp16()`` is true. The patcher's load device
    is ``model_management.get_torch_device()`` and its offload device is
    ``model_management.unet_offload_device()``.
    """
    gligen_model = gligen.load_gligen(comfy.utils.load_torch_file(ckpt_path, safe_load=True))
    if model_management.should_use_fp16():
        gligen_model = gligen_model.half()
    load_device = model_management.get_torch_device()
    offload_device = model_management.unet_offload_device()
    return comfy.model_patcher.ModelPatcher(gligen_model, load_device=load_device, offload_device=offload_device)
2023-04-19 13:36:19 +00:00
2023-06-09 16:24:24 +00:00
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
    """Deprecated: load a checkpoint, then apply settings from an LDM-style yaml config.

    Weights are loaded with ``load_checkpoint_guess_config``. The config
    (``config`` if given, otherwise parsed from ``config_path`` with
    ``yaml.safe_load``) is consulted only for:

    * ``model.params.parameterization``: ``"v"`` patches the model's
      ``model_sampling`` object with a V_PREDICTION discrete sampler;
    * ``model.params.cond_stage_config.params.layer_idx``: applied with
      ``clip.clip_layer`` when a text encoder was loaded.

    ``state_dict`` is accepted for backward compatibility but ignored.

    Returns:
        ``(model, clip, vae)``; ``clip`` and ``vae`` may be ``None``.
    """
    logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
    model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
    #TODO: this function is a mess and should be removed eventually
    if config is None:
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']

    if model_config_params.get("parameterization") == "v":
        m = model.clone()
        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION):
            pass
        m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
        model = m

    layer_idx = clip_config.get("params", {}).get("layer_idx", None)
    # clip is None when output_clip=False or the checkpoint holds no text-encoder
    # weights; calling clip_layer on it would raise AttributeError.
    if layer_idx is not None and clip is not None:
        clip.clip_layer(layer_idx)

    return (model, clip, vae)
2023-03-03 08:37:35 +00:00
2024-08-17 14:15:13 +00:00
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
    """Read a checkpoint file and build its components with model-type detection.

    File-reading wrapper around ``load_state_dict_guess_config``; all other
    arguments are forwarded to it unchanged.

    Returns:
        ``(model_patcher, clip, vae, clipvision)``.

    Raises:
        RuntimeError: if the model type of ``ckpt_path`` cannot be detected.
    """
    result = load_state_dict_guess_config(
        comfy.utils.load_torch_file(ckpt_path),
        output_vae,
        output_clip,
        output_clipvision,
        embedding_directory,
        output_model,
        model_options,
        te_model_options=te_model_options,
    )
    if result is not None:
        return result
    raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
2024-08-11 12:36:52 +00:00
2024-08-17 14:15:13 +00:00
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
    """Detect the model type of a full checkpoint state dict and build its parts.

    Args:
        sd: checkpoint state dict. Whatever keys remain after loading are
            logged as "left over keys" (presumably the loaders remove the keys
            they consume — TODO confirm).
        output_vae / output_clip / output_clipvision / output_model: which
            components to build. CLIP vision is built only when the detected
            config has a ``clip_vision_prefix``.
        embedding_directory: forwarded to ``CLIP``.
        model_options: may set ``"custom_operations"`` and ``"dtype"`` (alias
            ``"weight_dtype"``) for the diffusion model.
        te_model_options: forwarded to ``CLIP`` as ``model_options``.

    Returns:
        ``(model_patcher, clip, vae, clipvision)`` where an entry is ``None``
        when not requested or not available, or ``None`` when the diffusion
        model type cannot be detected.
    """
    clip = None
    clipvision = None
    vae = None
    model = None
    model_patcher = None

    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
    weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
    load_device = model_management.get_torch_device()

    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
    if model_config is None:
        return None

    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    # The checkpoint's stored dtype is an extra candidate unless the model is scaled fp8.
    if weight_dtype is not None and model_config.scaled_fp8 is None:
        unet_weight_dtype.append(weight_dtype)

    model_config.custom_operations = model_options.get("custom_operations", None)
    unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
    if unet_dtype is None:
        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)

    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

    if model_config.clip_vision_prefix is not None:
        if output_clipvision:
            clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

    if output_model:
        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
        model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
        model.load_model_weights(sd, diffusion_model_prefix)

    if output_vae:
        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
        vae_sd = model_config.process_vae_state_dict(vae_sd)
        vae = VAE(sd=vae_sd)

    if output_clip:
        clip_target = model_config.clip_target(state_dict=sd)
        if clip_target is not None:
            clip_sd = model_config.process_clip_state_dict(sd)
            if len(clip_sd) > 0:
                # Separate name so the diffusion-model parameter count is not overwritten.
                clip_parameters = comfy.utils.calculate_parameters(clip_sd)
                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=clip_parameters, model_options=te_model_options)
                m, u = clip.load_sd(clip_sd, full_model=True)
                if len(m) > 0:
                    # Only escalate to a warning if something other than these optional keys is missing.
                    m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
                    if len(m_filter) > 0:
                        logging.warning("clip missing: {}".format(m))
                    else:
                        logging.debug("clip missing: {}".format(m))

                if len(u) > 0:
                    logging.debug("clip unexpected: {}".format(u))
            else:
                logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

    left_over = sd.keys()
    if len(left_over) > 0:
        logging.debug("left over keys: {}".format(left_over))

    if output_model:
        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
        if inital_load_device != torch.device("cpu"):
            logging.info("loaded straight to GPU")
            model_management.load_models_gpu([model_patcher], force_full_load=True)

    return (model_patcher, clip, vae, clipvision)
2023-06-26 16:21:07 +00:00
2023-07-05 21:34:45 +00:00
2024-08-13 03:18:54 +00:00
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
    """Build a diffusion model from a state dict and wrap it in a ModelPatcher.

    Accepted layouts: the regular layout (optionally still carrying a full
    checkpoint's diffusion-model prefix, which is stripped), diffusers MMDiT
    (converted with ``convert_diffusers_mmdit``), and diffusers UNet
    (keys remapped with ``comfy.utils.unet_to_diffusers``).

    Args:
        sd: diffusion-model state dict.
        model_options: may set ``"dtype"`` (alias ``"weight_dtype"``),
            ``"custom_operations"`` and ``"fp8_optimizations"``.

    Returns:
        A ``ModelPatcher`` or ``None`` when the model type cannot be detected.
    """
    # "weight_dtype" is accepted as an alias of "dtype", as in load_state_dict_guess_config.
    dtype = model_options.get("dtype", model_options.get("weight_dtype", None))

    #Allow loading unets from checkpoint files
    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
    if len(temp_sd) > 0:
        sd = temp_sd

    parameters = comfy.utils.calculate_parameters(sd)
    weight_dtype = comfy.utils.weight_dtype(sd)

    load_device = model_management.get_torch_device()
    model_config = model_detection.model_config_from_unet(sd, "")

    if model_config is not None:
        new_sd = sd
    else:
        new_sd = model_detection.convert_diffusers_mmdit(sd, "")
        if new_sd is not None: #diffusers mmdit
            model_config = model_detection.model_config_from_unet(new_sd, "")
            if model_config is None:
                return None
        else: #diffusers unet
            model_config = model_detection.model_config_from_diffusers_unet(sd)
            if model_config is None:
                return None

            diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)

            new_sd = {}
            for k in diffusers_keys:
                if k in sd:
                    new_sd[diffusers_keys[k]] = sd.pop(k)
                else:
                    # Logs the target key followed by the diffusers key missing from sd.
                    logging.warning("{} {}".format(diffusers_keys[k], k))

    offload_device = model_management.unet_offload_device()
    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    # The state dict's stored dtype is an extra candidate unless the model is scaled fp8.
    if weight_dtype is not None and model_config.scaled_fp8 is None:
        unet_weight_dtype.append(weight_dtype)

    if dtype is None:
        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
    else:
        unet_dtype = dtype

    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
    model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
    if model_options.get("fp8_optimizations", False):
        model_config.optimizations["fp8"] = True

    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")
    left_over = sd.keys()
    if len(left_over) > 0:
        logging.info("left over keys in unet: {}".format(left_over))
    return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
2023-07-05 21:34:45 +00:00
2024-08-13 03:18:54 +00:00
def load_diffusion_model(unet_path, model_options={}):
    """Read a diffusion-model file and build a ModelPatcher from it.

    Delegates to ``load_diffusion_model_state_dict`` with ``model_options``.

    Raises:
        RuntimeError: if the model type of ``unet_path`` cannot be detected
            (an error is logged first).
    """
    model = load_diffusion_model_state_dict(comfy.utils.load_torch_file(unet_path), model_options=model_options)
    if model is not None:
        return model
    logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
    raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
2024-08-13 03:18:54 +00:00
def load_unet(unet_path, dtype=None):
    """Deprecated: use ``load_diffusion_model(unet_path, model_options={"dtype": dtype})``."""
    # Emit the notice through logging (not print) so it honours the app's log configuration.
    logging.warning("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
    return load_diffusion_model(unet_path, model_options={"dtype": dtype})
def load_unet_state_dict(sd, dtype=None):
    """Deprecated: use ``load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})``."""
    # Emit the notice through logging (not print) so it honours the app's log configuration.
    logging.warning("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
    return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
2024-04-08 04:36:22 +00:00
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
    """Combine model, text encoder, VAE and CLIP-vision weights into one saved file.

    The model (plus the text encoder's model, when ``clip`` is given) is loaded
    with ``force_patch_weights=True`` before ``state_dict_for_saving`` assembles
    the combined state dict. ``extra_keys`` entries are then copied in
    (overriding same-named keys), every non-contiguous tensor is replaced by a
    contiguous copy, and the result is written with
    ``comfy.utils.save_torch_file``.
    """
    models_to_load = [model]
    clip_sd = None
    if clip is not None:
        models_to_load.append(clip.load_model())
        clip_sd = clip.get_sd()
    vae_sd = vae.get_sd() if vae is not None else None

    model_management.load_models_gpu(models_to_load, force_patch_weights=True)
    clip_vision_sd = None if clip_vision is None else clip_vision.get_sd()
    sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
    sd.update(extra_keys)

    # Only values of existing keys are replaced, so iterating items() is safe.
    for key, tensor in sd.items():
        if not tensor.is_contiguous():
            sd[key] = tensor.contiguous()

    comfy.utils.save_torch_file(sd, output_path, metadata=metadata)