import torch
import contextlib
import math

from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
import yaml

import comfy.utils

from . import clip_vision
from . import gligen
from . import diffusers_convert
from . import model_base
from . import model_detection

from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
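
# Load a state dict into a model with strict=False, free each weight tensor as it
# is consumed to keep peak memory down, and print any keys the model expected but
# the checkpoint did not provide.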
def load_model_weights(model, sd):
    m, u = model.load_state_dict(sd, strict=False)
    m = set(m)
    unexpected_keys = set(u)

    k = list(sd.keys())
    for x in k:
        if x not in unexpected_keys:
            w = sd.pop(x)
            del w
    if len(m) > 0:
        print("missing", m)
    return model
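
# Normalize legacy CLIP key names before loading: SD1.x keys missing the
# "text_model." segment are renamed, float32 position_ids are rounded back to
# integer indices, and SD2.x "cond_stage_model.model." keys are converted to
# the "cond_stage_model.transformer.text_model." layout.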
def load_clip_weights(model, sd):
    k = list(sd.keys())
    for x in k:
        if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
            y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
            sd[y] = sd.pop(x)

    if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
        ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids']
        if ids.dtype == torch.float32:
            sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

    sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
    return load_model_weights(model, sd)
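
# Apply a LoRA to a model/CLIP pair without mutating the originals: both are
# cloned and the LoRA weights are attached as patches at the given strengths.
# Hypothetical usage sketch (the lora dict would come from comfy.utils.load_torch_file):
#   new_model, new_clip = load_lora_for_models(model, clip, lora, 1.0, 1.0)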
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    key_map = comfy.lora.model_lora_keys_unet(model.model)
    key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
    loaded = comfy.lora.load_lora(lora, key_map)
    new_modelpatcher = model.clone()
    k = new_modelpatcher.add_patches(loaded, strength_model)
    new_clip = clip.clone()
    k1 = new_clip.add_patches(loaded, strength_clip)
    k = set(k)
    k1 = set(k1)
    for x in loaded:
        if (x not in k) and (x not in k1):
            print("NOT LOADED", x)
    return (new_modelpatcher, new_clip)
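
# Wraps a CLIP text encoder and its tokenizer. Weight patches (e.g. from LoRAs)
# go through a ModelPatcher so clones share the underlying weights, and
# clip_layer() selects which hidden layer the encoder outputs.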
class CLIP:
    def __init__(self, target=None, embedding_directory=None, no_init=False):
        if no_init:
            return
        params = target.params.copy()
        clip = target.clip
        tokenizer = target.tokenizer

        load_device = model_management.text_encoder_device()
        offload_device = model_management.text_encoder_offload_device()
        params['device'] = offload_device
        if model_management.should_use_fp16(load_device, prioritize_performance=False):
            params['dtype'] = torch.float16
        else:
            params['dtype'] = torch.float32

        self.cond_stage_model = clip(**(params))

        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
        self.layer_idx = None

    def clone(self):
        n = CLIP(no_init=True)
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        return n

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        return self.patcher.add_patches(patches, strength_patch, strength_model)

    def clip_layer(self, layer_idx):
        self.layer_idx = layer_idx

    def tokenize(self, text, return_word_ids=False):
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)

    def encode_from_tokens(self, tokens, return_pooled=False):
        if self.layer_idx is not None:
            self.cond_stage_model.clip_layer(self.layer_idx)
        else:
            self.cond_stage_model.reset_clip_layer()

        self.load_model()
        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
        if return_pooled:
            return cond, pooled
        return cond

    def encode(self, text):
        tokens = self.tokenize(text)
        return self.encode_from_tokens(tokens)

    def load_sd(self, sd):
        return self.cond_stage_model.load_sd(sd)

    def get_sd(self):
        return self.cond_stage_model.state_dict()

    def load_model(self):
        model_management.load_model_gpu(self.patcher)
        return self.patcher

    def get_key_patches(self):
        return self.patcher.get_key_patches()
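
# Wraps the first-stage autoencoder. Accepts state dicts in native or diffusers
# format, keeps the model on an offload device between calls, and falls back to
# tiled decoding/encoding when a full pass runs out of memory.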
class VAE:
    def __init__(self, sd=None, device=None, config=None):
        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
            sd = diffusers_convert.convert_vae_state_dict(sd)

        if config is None:
            #default SD1.x/SD2.x VAE parameters
            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
            self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
        else:
            self.first_stage_model = AutoencoderKL(**(config['params']))
        self.first_stage_model = self.first_stage_model.eval()

        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
        if len(m) > 0:
            print("Missing VAE keys", m)
        if len(u) > 0:
            print("Leftover VAE keys", u)

        if device is None:
            device = model_management.vae_device()
        self.device = device
        self.offload_device = model_management.vae_offload_device()
        self.vae_dtype = model_management.vae_dtype()
        self.first_stage_model.to(self.vae_dtype)
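
    # Decode latents in overlapping tiles and average three tile layouts (square,
    # wide and tall) to hide seams; the final /2.0 and clamp map the decoder's
    # [-1, 1] output into the [0, 1] pixel range.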
    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
        steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)

        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
        output = torch.clamp((
            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=8, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=8, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=8, pbar=pbar))
            / 3.0) / 2.0, min=0.0, max=1.0)
        return output

    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)

        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1/8), out_channels=4, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1/8), out_channels=4, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1/8), out_channels=4, pbar=pbar)
        samples /= 3.0
        return samples
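
    # Estimate how much VRAM a full decode needs, free memory accordingly, and
    # pick the largest batch size that fits; on OOM, fall back to decode_tiled_().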
    def decode(self, samples_in):
        self.first_stage_model = self.first_stage_model.to(self.device)
        try:
            memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
            model_management.free_memory(memory_used, self.device)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)
            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
                pixel_samples[x:x + batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
        except model_management.OOM_EXCEPTION as e:
            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            pixel_samples = self.decode_tiled_(samples_in)

        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        pixel_samples = pixel_samples.cpu().movedim(1, -1)
        return pixel_samples

    def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
        self.first_stage_model = self.first_stage_model.to(self.device)
        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        return output.movedim(1, -1)

    def encode(self, pixel_samples):
        self.first_stage_model = self.first_stage_model.to(self.device)
        pixel_samples = pixel_samples.movedim(-1, 1)
        try:
            memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
            model_management.free_memory(memory_used, self.device)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)
            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
            for x in range(0, pixel_samples.shape[0], batch_number):
                pixels_in = (2. * pixel_samples[x:x + batch_number] - 1.).to(self.vae_dtype).to(self.device)
                samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
        except model_management.OOM_EXCEPTION as e:
            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
            samples = self.encode_tiled_(pixel_samples)

        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        return samples

    def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        self.first_stage_model = self.first_stage_model.to(self.device)
        pixel_samples = pixel_samples.movedim(-1, 1)
        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        return samples

    def get_sd(self):
        return self.first_stage_model.state_dict()
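
# T2I-Adapter style model: projects CLIP vision output (last_hidden_state) into
# conditioning tokens. Only checkpoints containing a "style_embedding" key are
# recognized.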
class StyleModel:
    def __init__(self, model, device="cpu"):
        self.model = model

    def get_cond(self, input):
        return self.model(input.last_hidden_state)


def load_style_model(ckpt_path):
    model_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    keys = model_data.keys()
    if "style_embedding" in keys:
        model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
    else:
        raise Exception("invalid style model {}".format(ckpt_path))
    model.load_state_dict(model_data)
    return StyleModel(model)
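
# Load standalone text encoder checkpoints, guessing the architecture from the
# state dict: OpenCLIP-layout weights are converted first, then a layer 30 key
# implies the SDXL refiner encoder, a layer 22 key implies SD2.x, anything else
# SD1.x; passing two checkpoints selects the dual SDXL encoders.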
def load_clip(ckpt_paths, embedding_directory=None):
    clip_data = []
    for p in ckpt_paths:
        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))

    class EmptyClass:
        pass

    for i in range(len(clip_data)):
        if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
            clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32)

    clip_target = EmptyClass()
    clip_target.params = {}
    if len(clip_data) == 1:
        if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
            clip_target.clip = sd2_clip.SD2ClipModel
            clip_target.tokenizer = sd2_clip.SD2Tokenizer
        else:
            clip_target.clip = sd1_clip.SD1ClipModel
            clip_target.tokenizer = sd1_clip.SD1Tokenizer
    else:
        clip_target.clip = sdxl_clip.SDXLClipModel
        clip_target.tokenizer = sdxl_clip.SDXLTokenizer

    clip = CLIP(clip_target, embedding_directory=embedding_directory)
    for c in clip_data:
        m, u = clip.load_sd(c)
        if len(m) > 0:
            print("clip missing:", m)
        if len(u) > 0:
            print("clip unexpected:", u)
    return clip
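
# Load a GLIGEN (grounded generation) model and wrap it in a ModelPatcher so the
# usual load/offload machinery applies to it.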
def load_gligen(ckpt_path):
    data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    model = gligen.load_gligen(data)
    if model_management.should_use_fp16():
        model = model.half()
    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
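
# Load a checkpoint from an explicit LDM yaml config instead of auto-detecting
# the model type; load_checkpoint_guess_config() below is the preferred path.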
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
    #TODO: this function is a mess and should be removed eventually
    if config is None:
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']

    fp16 = False
    unet_config = {} #default so the model_config.unet_config assignment below can't hit a NameError if the yaml has no unet_config section
    if "unet_config" in model_config_params:
        if "params" in model_config_params["unet_config"]:
            unet_config = model_config_params["unet_config"]["params"]
            if "use_fp16" in unet_config:
                fp16 = unet_config.pop("use_fp16")
                if fp16:
                    unet_config["dtype"] = torch.float16

    noise_aug_config = None
    if "noise_aug_config" in model_config_params:
        noise_aug_config = model_config_params["noise_aug_config"]

    model_type = model_base.ModelType.EPS
    if "parameterization" in model_config_params:
        if model_config_params["parameterization"] == "v":
            model_type = model_base.ModelType.V_PREDICTION

    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    if state_dict is None:
        state_dict = comfy.utils.load_torch_file(ckpt_path)

    class EmptyClass:
        pass

    model_config = comfy.supported_models_base.BASE({})

    from . import latent_formats
    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
    model_config.unet_config = unet_config

    if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
    else:
        model = model_base.BaseModel(model_config, model_type=model_type)

    if config['model']["target"].endswith("LatentInpaintDiffusion"):
        model.set_inpaint()

    if fp16:
        model = model.half()

    offload_device = model_management.unet_offload_device()
    model = model.to(offload_device)
    model.load_model_weights(state_dict, "model.diffusion_model.")

    if output_vae:
        vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
        vae = VAE(sd=vae_sd, config=vae_config)

    if output_clip:
        w = WeightsLoader()
        clip_target = EmptyClass()
        clip_target.params = clip_config.get("params", {})
        if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
            clip_target.clip = sd2_clip.SD2ClipModel
            clip_target.tokenizer = sd2_clip.SD2Tokenizer
        elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
            clip_target.clip = sd1_clip.SD1ClipModel
            clip_target.tokenizer = sd1_clip.SD1Tokenizer

        clip = CLIP(clip_target, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_clip_weights(w, state_dict)

    return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
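
# Load a checkpoint while auto-detecting the model type from its UNet state dict
# (via model_detection); returns (model_patcher, clip, vae, clipvision).
# Hypothetical usage sketch (the path is illustrative):
#   model_patcher, clip, vae, _ = load_checkpoint_guess_config("models/checkpoints/v1-5.safetensors")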
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
    sd = comfy.utils.load_torch_file(ckpt_path)
    sd_keys = sd.keys()
    clip = None
    clipvision = None
    vae = None
    model = None
    model_patcher = None
    clip_target = None

    parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
    unet_dtype = model_management.unet_dtype(model_params=parameters)

    class WeightsLoader(torch.nn.Module):
        pass

    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
    if model_config is None:
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))

    if model_config.clip_vision_prefix is not None:
        if output_clipvision:
            clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

    if output_model:
        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
        offload_device = model_management.unet_offload_device()
        model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
        model.load_model_weights(sd, "model.diffusion_model.")

    if output_vae:
        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
        vae = VAE(sd=vae_sd)

    if output_clip:
        w = WeightsLoader()
        clip_target = model_config.clip_target()
        if clip_target is not None:
            clip = CLIP(clip_target, embedding_directory=embedding_directory)
            w.cond_stage_model = clip.cond_stage_model
            sd = model_config.process_clip_state_dict(sd)
            load_model_weights(w, sd)

    left_over = sd.keys()
    if len(left_over) > 0:
        print("left over keys:", left_over)

    if output_model:
        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
        if inital_load_device != torch.device("cpu"):
            print("loaded straight to GPU")
            model_management.load_model_gpu(model_patcher)

    return (model_patcher, clip, vae, clipvision)
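
# Load a standalone UNet checkpoint. Native (ldm-layout) state dicts are used
# as-is; diffusers-layout state dicts are remapped into the native layout via
# the key map from comfy.utils.unet_to_diffusers before the model is built.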
def load_unet(unet_path): #load unet in diffusers or ldm format
    sd = comfy.utils.load_torch_file(unet_path)
    parameters = comfy.utils.calculate_parameters(sd)
    unet_dtype = model_management.unet_dtype(model_params=parameters)
    if "input_blocks.0.0.weight" in sd: #ldm
        model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
        if model_config is None:
            raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
        new_sd = sd
    else: #diffusers
        model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
        if model_config is None:
            print("ERROR UNSUPPORTED UNET", unet_path)
            return None

        diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)

        new_sd = {}
        for k in diffusers_keys:
            if k in sd:
                new_sd[diffusers_keys[k]] = sd.pop(k)
            else:
                print(diffusers_keys[k], k) #report expected diffusers keys missing from the checkpoint

    offload_device = model_management.unet_offload_device()
    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")
    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
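
# Save a full checkpoint: the model and CLIP patchers are loaded first so their
# weight patches are applied before the combined state dict is written out.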
def save_checkpoint(output_path, model, clip, vae, metadata=None):
    model_management.load_models_gpu([model, clip.load_model()])
    sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
    comfy.utils.save_torch_file(sd, output_path, metadata=metadata)