2023-04-06 03:41:23 +00:00
import psutil
2024-03-10 15:37:08 +00:00
import logging
2023-04-06 03:41:23 +00:00
from enum import Enum
2023-05-05 04:19:35 +00:00
from comfy . cli_args import args
2023-08-26 15:52:07 +00:00
import comfy . utils
2023-06-02 19:05:25 +00:00
import torch
2023-08-17 05:06:34 +00:00
import sys
2023-02-08 08:17:54 +00:00
2023-04-06 03:41:23 +00:00
class VRAMState(Enum):
    """How model weights are managed between system memory and VRAM."""
    DISABLED = 0     # No VRAM present: models never need to be moved to VRAM.
    NO_VRAM = 1      # Very low VRAM: every VRAM-saving option is enabled.
    LOW_VRAM = 2
    NORMAL_VRAM = 3
    HIGH_VRAM = 4
    SHARED = 5       # No dedicated VRAM: CPU and GPU share memory, but models still have to be moved between them.
2023-06-03 15:05:37 +00:00
class CPUState(Enum):
    """Which compute backend the process runs on."""
    GPU = 0  # discrete accelerator (CUDA/ROCm/XPU/DirectML)
    CPU = 1  # forced CPU execution
    MPS = 2  # Apple Metal
2023-02-08 16:37:10 +00:00
2023-04-06 03:41:23 +00:00
# Determine VRAM state. These are starting values; the module-level code below
# refines them from the detected hardware and the command-line arguments.
vram_state = VRAMState.NORMAL_VRAM
# Requested low/no-VRAM mode; only applied to vram_state when lowvram_available.
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU

total_vram = 0

lowvram_available = True
xpu_available = False
2023-02-08 16:37:10 +00:00
2023-12-17 21:59:21 +00:00
# --deterministic: ask torch to use deterministic kernels. warn_only=True makes
# ops without a deterministic implementation warn instead of raising.
if args.deterministic:
    logging.info("Using deterministic algorithms for pytorch")
    torch.use_deterministic_algorithms(True, warn_only=True)
2023-04-28 18:28:57 +00:00
# DirectML backend, enabled by --directml [device_index].
directml_enabled = False
if args.directml is not None:
    import torch_directml
    directml_enabled = True
    device_index = args.directml
    # A negative index falls back to torch_directml.device() with no explicit index.
    directml_device = torch_directml.device() if device_index < 0 else torch_directml.device(device_index)
    logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
    # torch_directml.disable_tiled_resources(True)
    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
2023-04-28 18:28:57 +00:00
2023-02-08 19:05:31 +00:00
# Optional Intel XPU support through intel_extension_for_pytorch. This is best
# effort: any failure (package missing, broken install) leaves XPU disabled.
# `except Exception` (instead of a bare except) no longer swallows
# KeyboardInterrupt/SystemExit raised during import.
try:
    import intel_extension_for_pytorch as ipex
    if torch.xpu.is_available():
        xpu_available = True
except Exception:
    pass

# Apple Metal (MPS) backend; also best effort, since torch builds may lack it.
try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except Exception:
    pass

# --cpu overrides whatever accelerator was detected above.
if args.cpu:
    cpu_state = CPUState.CPU
2023-09-03 01:22:10 +00:00
def is_intel_xpu():
    """Return True when running on the GPU path and an Intel XPU was detected at import time."""
    return cpu_state == CPUState.GPU and bool(xpu_available)
def get_torch_device():
    """Return the torch device models should run on.

    The priority order is DirectML (if enabled), then MPS, then CPU (per
    cpu_state), then Intel XPU, and finally the current CUDA device.
    """
    if directml_enabled:
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    if is_intel_xpu():
        return torch.device("xpu")
    return torch.device(torch.cuda.current_device())
def get_total_memory(dev=None, torch_total_too=False):
    """Return the total memory of ``dev`` in bytes (``dev`` defaults to get_torch_device()).

    On cpu/mps this is total system RAM. On DirectML it is a fixed 1 GiB
    placeholder. On XPU it is the device's total_memory, and on CUDA the total
    reported by mem_get_info.

    With torch_total_too=True, returns ``(total, torch_total)``. ``torch_total``
    is the memory currently reserved by torch's allocator on XPU/CUDA, or the
    same value as ``total`` otherwise.
    """
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and dev.type in ('cpu', 'mps'):
        total = psutil.virtual_memory().total
        torch_total = total
    elif directml_enabled:
        total = 1024 * 1024 * 1024 #TODO
        torch_total = total
    elif is_intel_xpu():
        torch_total = torch.xpu.memory_stats(dev)['reserved_bytes.all.current']
        total = torch.xpu.get_device_properties(dev).total_memory
    else:
        torch_total = torch.cuda.memory_stats(dev)['reserved_bytes.all.current']
        _, total = torch.cuda.mem_get_info(dev)

    return (total, torch_total) if torch_total_too else total
# Sizes below are in MiB (bytes / 1024**2).
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))

# Cards with 4 GiB or less default to low-VRAM mode unless --normalvram/--cpu was given.
if not (args.normalvram or args.cpu) and lowvram_available and total_vram <= 4096:
    logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
    set_vram_to = VRAMState.LOW_VRAM
2023-03-22 18:49:00 +00:00
# Exception type that torch raises when a device allocation fails. Torch builds
# without torch.cuda.OutOfMemoryError fall back to catching any Exception.
# getattr() with a default replaces a bare `except:` that would also have
# swallowed KeyboardInterrupt/SystemExit.
OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)
2023-04-09 05:31:47 +00:00
# xformers detection.
#   XFORMERS_IS_AVAILABLE - xformers may be used for attention.
#   XFORMERS_ENABLED_VAE  - xformers may also be used in the VAE (turned off for 0.0.18).
#   XFORMERS_VERSION      - reported version string, "" when unknown.
# The failures below are handled as best effort with `except Exception`
# instead of bare excepts, so Ctrl-C during import is no longer swallowed.
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
    XFORMERS_IS_AVAILABLE = False
else:
    try:
        import xformers
        import xformers.ops
        # NOTE(review): _has_cpp_library presumably reports whether the compiled
        # extension is present; builds without the attribute are treated as usable.
        XFORMERS_IS_AVAILABLE = getattr(xformers, "_has_cpp_library", True)
        try:
            XFORMERS_VERSION = xformers.version.__version__
            logging.info("xformers version: {}".format(XFORMERS_VERSION))
            if XFORMERS_VERSION.startswith("0.0.18"):
                logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
                logging.warning("Please downgrade or upgrade xformers to a different version.\n")
                XFORMERS_ENABLED_VAE = False
        except Exception:
            pass
    except Exception:
        XFORMERS_IS_AVAILABLE = False
2023-03-13 15:36:48 +00:00
2023-06-26 16:55:07 +00:00
def is_nvidia():
    """Return True on the GPU path when torch was built with CUDA (torch.version.cuda is set)."""
    if cpu_state != CPUState.GPU:
        return False
    return bool(torch.version.cuda)
2023-06-26 16:55:07 +00:00
2023-10-12 01:29:03 +00:00
# Attention backend and default VAE dtype selection.
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
    ENABLE_PYTORCH_ATTENTION = True
    XFORMERS_IS_AVAILABLE = False

VAE_DTYPE = torch.float32

try:
    if is_nvidia():
        torch_version = torch.version.__version__
        # Parse the whole major component: reading only the first character
        # would misread e.g. "10.0" as major version 1.
        if int(torch_version.split(".")[0]) >= 2:
            if not (ENABLE_PYTORCH_ATTENTION or args.use_split_cross_attention or args.use_quad_cross_attention):
                ENABLE_PYTORCH_ATTENTION = True
            if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
                VAE_DTYPE = torch.bfloat16
    if is_intel_xpu():
        if not (args.use_split_cross_attention or args.use_quad_cross_attention):
            ENABLE_PYTORCH_ATTENTION = True
except Exception:
    # Best effort: if hardware probing fails, keep the defaults chosen above.
    pass

if is_intel_xpu():
    VAE_DTYPE = torch.bfloat16

if args.cpu_vae:
    VAE_DTYPE = torch.float32

# Explicit VAE precision flags take precedence over everything above.
if args.fp16_vae:
    VAE_DTYPE = torch.float16
elif args.bf16_vae:
    VAE_DTYPE = torch.bfloat16
elif args.fp32_vae:
    VAE_DTYPE = torch.float32
2023-06-26 16:55:07 +00:00
2023-04-06 03:41:23 +00:00
# When PyTorch attention is selected, allow the math, flash and
# memory-efficient scaled-dot-product-attention backends.
if ENABLE_PYTORCH_ATTENTION:
    cuda_backends = torch.backends.cuda
    cuda_backends.enable_math_sdp(True)
    cuda_backends.enable_flash_sdp(True)
    cuda_backends.enable_mem_efficient_sdp(True)
2023-03-12 19:44:16 +00:00
2023-04-06 03:41:23 +00:00
# Explicit VRAM-mode flags from the command line (mutually exclusive, first match wins).
if args.lowvram:
    set_vram_to = VRAMState.LOW_VRAM
    # --lowvram turns low-VRAM support back on even where it was disabled (e.g. DirectML).
    lowvram_available = True
elif args.novram:
    set_vram_to = VRAMState.NO_VRAM
elif args.highvram or args.gpu_only:
    vram_state = VRAMState.HIGH_VRAM
2023-03-24 18:30:43 +00:00
2023-04-07 04:27:54 +00:00
# Global precision overrides requested on the command line.
FORCE_FP32 = bool(args.force_fp32)
if FORCE_FP32:
    logging.info("Forcing FP32, if this improves things please report it.")

FORCE_FP16 = bool(args.force_fp16)
if FORCE_FP16:
    logging.info("Forcing FP16.")
2023-05-30 16:36:41 +00:00
# Apply a requested low/no-VRAM mode only when low-VRAM loading is supported.
if lowvram_available and set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
    vram_state = set_vram_to

# MPS shares memory with the CPU; any other non-GPU backend has no VRAM to manage.
if cpu_state == CPUState.MPS:
    vram_state = VRAMState.SHARED
elif cpu_state != CPUState.GPU:
    vram_state = VRAMState.DISABLED

logging.info(f"Set vram state to: {vram_state.name}")
2023-02-08 16:37:10 +00:00
2023-08-17 07:12:37 +00:00
# With smart memory disabled, free_memory() unloads models without first checking free memory.
DISABLE_SMART_MEMORY = args.disable_smart_memory
if DISABLE_SMART_MEMORY:
    logging.info("Disabling smart memory management")
2023-06-03 15:05:37 +00:00
2023-05-13 21:11:27 +00:00
def get_torch_device_name(device):
    """Return a human-readable description of ``device`` for logging.

    For CUDA devices this includes the GPU name and the allocator backend.
    Other devices that have a ``.type`` are reported by that type only.
    Objects without ``.type`` are described as XPU or CUDA devices.
    """
    if hasattr(device, 'type'):
        if device.type == "cuda":
            # The allocator backend is informational only. NOTE(review): presumably
            # get_allocator_backend() is missing on older torch builds; fall back to "".
            # `except Exception` replaces a bare except that also caught Ctrl-C.
            try:
                allocator_backend = torch.cuda.get_allocator_backend()
            except Exception:
                allocator_backend = ""
            return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
        else:
            return "{}".format(device.type)
    elif is_intel_xpu():
        return "{} {}".format(device, torch.xpu.get_device_name(device))
    else:
        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
2023-05-13 21:11:27 +00:00
# Log the default device. Failing to describe it is not fatal, but only
# ordinary errors are tolerated: a bare except would also swallow Ctrl-C.
try:
    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except Exception:
    logging.warning("Could not pick default device.")

logging.info("VAE dtype: {}".format(VAE_DTYPE))
2023-02-08 08:17:54 +00:00
2023-08-17 05:06:34 +00:00
# LoadedModel entries currently loaded; new and re-used entries are inserted at
# the front, so the list is ordered most recently used first.
current_loaded_models = []


def module_size(module):
    """Return the total size in bytes of every tensor in ``module.state_dict()``."""
    return sum(tensor.nelement() * tensor.element_size() for tensor in module.state_dict().values())
2023-08-17 05:06:34 +00:00
class LoadedModel:
    """Bookkeeping wrapper around a model patcher that can be loaded onto its load device.

    NOTE(review): ``model`` is assumed to be ModelPatcher-like. The interface is
    inferred from the calls below: load_device, current_device, offload_device,
    model_size(), model_dtype(), patch_model(), patch_model_lowvram(),
    unpatch_model() and model_patches_to().
    """

    def __init__(self, model):
        self.model = model
        self.device = model.load_device
        # True once weight patches were applied and not yet removed by an unload.
        self.weights_loaded = False

    def model_memory(self):
        """Return the size reported by ``model.model_size()``."""
        return self.model.model_size()

    def model_memory_required(self, device):
        """Return the memory needed to load onto ``device``; 0 if the model is already there."""
        if device == self.model.current_device:
            return 0
        return self.model_memory()

    def model_load(self, lowvram_model_memory=0):
        """Patch the model onto ``self.device`` and return the patched model.

        When lowvram_model_memory > 0 and weights still need patching, the model
        is loaded with patch_model_lowvram() under that memory budget. Otherwise
        patch_model() is used, and weight patches are re-applied only if they
        are not already loaded. On failure the model is unpatched back to its
        offload device and the original exception propagates.
        """
        patch_model_to = self.device

        self.model.model_patches_to(self.device)
        self.model.model_patches_to(self.model.model_dtype())

        load_weights = not self.weights_loaded

        try:
            if lowvram_model_memory > 0 and load_weights:
                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
            else:
                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
        except Exception:
            self.model.unpatch_model(self.model.offload_device)
            self.model_unload()
            raise  # bare raise re-raises the active exception with its traceback untouched

        if is_intel_xpu() and not args.disable_ipex_optimize:
            self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

        self.weights_loaded = True
        return self.real_model

    def model_unload(self, unpatch_weights=True):
        """Unpatch the model back to its offload device.

        With unpatch_weights=False the weight patches are kept, and
        ``weights_loaded`` keeps its current value.
        """
        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
        self.model.model_patches_to(self.model.offload_device)
        self.weights_loaded = self.weights_loaded and not unpatch_weights

    def __eq__(self, other):
        """Entries are equal when they wrap the very same model object.

        Objects without a ``.model`` attribute compare unequal (NotImplemented)
        instead of raising AttributeError.
        """
        try:
            return self.model is other.model
        except AttributeError:
            return NotImplemented
2023-07-15 17:24:05 +00:00
2023-08-17 05:06:34 +00:00
def minimum_inference_memory():
    """Fixed headroom in bytes (1 GiB) kept free on top of model weights when loading models."""
    one_gib = 1 << 30
    return one_gib
2024-03-20 17:53:45 +00:00
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
    """Unload every entry of current_loaded_models that is a clone of ``model``.

    Returns None when no clone is loaded. It also returns None without
    unloading anything when force_unload is False, unload_weights_only is True
    and every clone shares ``model``'s weights. Otherwise all clones are unloaded
    and the return value tells whether their weight patches were removed: True
    when at least one clone has different weights.
    """
    matches = [idx for idx in range(len(current_loaded_models))
               if model.is_clone(current_loaded_models[idx].model)]
    if not matches:
        return None

    # Highest index first, so popping an entry never shifts one still to be processed.
    to_unload = matches[::-1]

    same_weights = [model.clone_has_same_weights(current_loaded_models[idx].model) for idx in to_unload]
    unload_weight = not all(same_weights)

    if not force_unload and unload_weights_only and not unload_weight:
        return None

    for idx in to_unload:
        logging.debug("unload clone {} {}".format(idx, unload_weight))
        current_loaded_models.pop(idx).model_unload(unpatch_weights=unload_weight)
    return unload_weight
2023-08-17 05:06:34 +00:00
def free_memory(memory_required, device, keep_loaded=None):
    """Unload models from ``device`` until more than ``memory_required`` bytes are free.

    Entries are visited from the end of current_loaded_models, i.e. least
    recently used first. Entries equal to one in ``keep_loaded`` are never
    unloaded. With DISABLE_SMART_MEMORY every other model on ``device`` is
    unloaded regardless of free memory.

    If something was unloaded, the allocator cache is emptied. If nothing was
    unloaded, the cache is still emptied (outside HIGH_VRAM) when torch's
    reusable cached memory exceeds 25% of the total free memory.

    keep_loaded: iterable of LoadedModel entries to keep; None means keep none.
    The default used to be a mutable ``[]`` shared across calls.
    """
    if keep_loaded is None:
        keep_loaded = []

    unloaded_model = False
    for i in range(len(current_loaded_models) - 1, -1, -1):
        if not DISABLE_SMART_MEMORY and get_free_memory(device) > memory_required:
            break
        shift_model = current_loaded_models[i]
        if shift_model.device == device and shift_model not in keep_loaded:
            m = current_loaded_models.pop(i)
            m.model_unload()
            del m
            unloaded_model = True

    if unloaded_model:
        soft_empty_cache()
    elif vram_state != VRAMState.HIGH_VRAM:
        mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
        if mem_free_torch > mem_free_total * 0.25:
            soft_empty_cache()
2023-08-17 05:06:34 +00:00
def load_models_gpu(models, memory_required=0):
    """Ensure every model patcher in ``models`` is loaded on its load device.

    Models that are already loaded move to the front of current_loaded_models.
    For new models, the function:
      1. unloads clones whose weights differ,
      2. frees 1.3x the required model memory plus
         max(minimum_inference_memory(), memory_required) on each non-CPU device,
      3. unloads the remaining clones, keeping shared weights where possible,
      4. loads each model, falling back to low-VRAM partial loading when it
         does not fit.
    """
    inference_memory = minimum_inference_memory()
    extra_mem = max(inference_memory, memory_required)

    models_to_load = []
    models_already_loaded = []
    for x in models:
        loaded_model = LoadedModel(x)
        if loaded_model in current_loaded_models:
            # Move the existing entry to the front (most recently used).
            index = current_loaded_models.index(loaded_model)
            current_loaded_models.insert(0, current_loaded_models.pop(index))
            # The fresh wrapper equals the stored entry (same model), so it works as a keep_loaded key.
            models_already_loaded.append(loaded_model)
        else:
            if hasattr(x, "model"):
                logging.info(f"Requested to load {x.model.__class__.__name__}")
            models_to_load.append(loaded_model)

    if not models_to_load:
        # Nothing new to load: only make room for inference next to the loaded models.
        devs = set(map(lambda a: a.device, models_already_loaded))
        for d in devs:
            if d != torch.device("cpu"):
                free_memory(extra_mem, d, models_already_loaded)
        return

    logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")

    total_memory_required = {}
    for loaded_model in models_to_load:
        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

    for loaded_model in models_to_load:
        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
        if weights_unloaded is not None:
            loaded_model.weights_loaded = not weights_unloaded

    for loaded_model in models_to_load:
        model = loaded_model.model
        torch_dev = model.load_device
        if is_device_cpu(torch_dev):
            vram_set_state = VRAMState.DISABLED
        else:
            vram_set_state = vram_state

        lowvram_model_memory = 0
        if lowvram_available and vram_set_state in (VRAMState.LOW_VRAM, VRAMState.NORMAL_VRAM):
            model_size = loaded_model.model_memory_required(torch_dev)
            current_free_mem = get_free_memory(torch_dev)
            # Low-VRAM weight budget: max(64 MiB, (free - 1 GiB) / 1.3).
            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
            if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
                vram_set_state = VRAMState.LOW_VRAM
            else:
                lowvram_model_memory = 0

        if vram_set_state == VRAMState.NO_VRAM:
            lowvram_model_memory = 64 * 1024 * 1024

        loaded_model.model_load(lowvram_model_memory)
        current_loaded_models.insert(0, loaded_model)
def load_model_gpu(model):
    """Load a single model; a thin wrapper around load_models_gpu()."""
    single = [model]
    return load_models_gpu(single)
def cleanup_models():
    """Unload every loaded model whose wrapped model object is no longer referenced elsewhere.

    An entry is removed when sys.getrefcount() of its ``.model`` is <= 2.
    NOTE(review): presumably the two counted references are the LoadedModel's
    own ``.model`` attribute and getrefcount's temporary argument.
    """
    to_delete = []
    for i in range(len(current_loaded_models)):
        # Binding the model to a local here would add a reference and change the count.
        if sys.getrefcount(current_loaded_models[i].model) <= 2:
            # Prepend, so indices end up in descending order and popping doesn't shift later ones.
            to_delete = [i] + to_delete
    for i in to_delete:
        x = current_loaded_models.pop(i)
        x.model_unload()
        del x
2023-02-17 20:45:29 +00:00
2023-08-24 21:20:54 +00:00
def dtype_size ( dtype ) :
dtype_size = 4
if dtype == torch . float16 or dtype == torch . bfloat16 :
dtype_size = 2
2023-12-04 16:52:06 +00:00
elif dtype == torch . float32 :
dtype_size = 4
else :
try :
dtype_size = dtype . itemsize
except : #Old pytorch doesn't have .itemsize
pass
2023-08-24 21:20:54 +00:00
return dtype_size
2023-07-01 17:22:51 +00:00
def unet_offload_device():
    """Device the UNet is parked on when idle: the compute device in HIGH_VRAM mode, else the CPU."""
    if vram_state != VRAMState.HIGH_VRAM:
        return torch.device("cpu")
    return get_torch_device()
2023-08-17 05:06:34 +00:00
def unet_inital_load_device(parameters, dtype):
    """Pick the device UNet weights are initially loaded onto.

    HIGH_VRAM always uses the compute device, and disabled smart memory always
    uses the CPU. Otherwise the compute device is chosen only when it has more
    free memory than the CPU and the model (parameters * bytes per element)
    fits in that free memory.
    """
    torch_dev = get_torch_device()
    if vram_state == VRAMState.HIGH_VRAM:
        return torch_dev

    cpu_dev = torch.device("cpu")
    if DISABLE_SMART_MEMORY:
        return cpu_dev

    needed = dtype_size(dtype) * parameters
    free_on_dev = get_free_memory(torch_dev)
    free_on_cpu = get_free_memory(cpu_dev)
    prefer_dev = free_on_dev > free_on_cpu and needed < free_on_dev
    return torch_dev if prefer_dev else cpu_dev
2024-02-16 15:55:08 +00:00
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Choose the weight dtype for a UNet with ``model_params`` parameters.

    Explicit flags win: --bf16-unet, --fp16-unet, then the fp8 variants.
    Otherwise fp16 and then bf16 are chosen when the should_use_* heuristics
    (with manual_cast=True) allow them and the dtype is listed in
    ``supported_dtypes``. The fallback is float32.

    supported_dtypes: any container supporting ``in``. The default is an
    immutable tuple rather than a mutable list shared between calls.
    """
    if args.bf16_unet:
        return torch.bfloat16
    if args.fp16_unet:
        return torch.float16
    if args.fp8_e4m3fn_unet:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2
    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
        if torch.float16 in supported_dtypes:
            return torch.float16
    if should_use_bf16(device, model_params=model_params, manual_cast=True):
        if torch.bfloat16 in supported_dtypes:
            return torch.bfloat16
    return torch.float32
2023-12-11 23:24:44 +00:00
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Return the dtype to compute in when ``weight_dtype`` can't be used directly, else None.

    float32 weights, and fp16/bf16 weights the device supports, need no cast.
    Otherwise the result is float16 if supported and listed, then bfloat16,
    then float32.

    supported_dtypes: any container supporting ``in``. The default is an
    immutable tuple rather than a mutable list shared between calls.
    """
    if weight_dtype == torch.float32:
        return None

    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
    if fp16_supported and weight_dtype == torch.float16:
        return None

    bf16_supported = should_use_bf16(inference_device)
    if bf16_supported and weight_dtype == torch.bfloat16:
        return None

    if fp16_supported and torch.float16 in supported_dtypes:
        return torch.float16
    elif bf16_supported and torch.bfloat16 in supported_dtypes:
        return torch.bfloat16
    else:
        return torch.float32
2023-07-01 16:37:23 +00:00
def text_encoder_offload_device():
    """Idle device for text encoders: the compute device with --gpu-only, else the CPU."""
    return get_torch_device() if args.gpu_only else torch.device("cpu")
2023-07-01 16:37:23 +00:00
def text_encoder_device():
    """Device text encoders run on.

    --gpu-only forces the compute device. In HIGH/NORMAL VRAM mode the compute
    device is used when fp16 is usable (never on Intel XPU). Everything else
    runs on the CPU.
    """
    if args.gpu_only:
        return get_torch_device()
    if vram_state not in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM):
        return torch.device("cpu")
    if is_intel_xpu():
        return torch.device("cpu")
    if should_use_fp16(prioritize_performance=False):
        return get_torch_device()
    return torch.device("cpu")
2023-11-17 07:56:59 +00:00
def text_encoder_dtype(device=None):
    """Return the dtype for text-encoder weights.

    Explicit flags win: --fp8-e4m3fn-text-enc, --fp8-e5m2-text-enc,
    --fp16-text-enc, --fp32-text-enc. Otherwise the result is float16 on every
    device. ``device`` is currently unused and kept for API compatibility; the
    former is_device_cpu(device) branch also returned float16, so it was dead
    weight and has been removed.
    """
    if args.fp8_e4m3fn_text_enc:
        return torch.float8_e4m3fn
    elif args.fp8_e5m2_text_enc:
        return torch.float8_e5m2
    elif args.fp16_text_enc:
        return torch.float16
    elif args.fp32_text_enc:
        return torch.float32
    return torch.float16
2023-11-17 07:56:59 +00:00
2023-12-08 07:35:45 +00:00
def intermediate_device():
    """Device for intermediates: the compute device with --gpu-only, else the CPU."""
    if not args.gpu_only:
        return torch.device("cpu")
    return get_torch_device()
2023-07-01 19:22:40 +00:00
def vae_device():
    """Device the VAE runs on: the CPU with --cpu-vae, else the compute device."""
    return torch.device("cpu") if args.cpu_vae else get_torch_device()
def vae_offload_device():
    """Idle device for the VAE: the compute device with --gpu-only, else the CPU."""
    if not args.gpu_only:
        return torch.device("cpu")
    return get_torch_device()
2023-07-06 22:04:28 +00:00
def vae_dtype():
    """Return the VAE dtype selected at import time (module-level VAE_DTYPE)."""
    return VAE_DTYPE
2023-07-06 22:04:28 +00:00
2023-03-06 15:50:50 +00:00
def get_autocast_device(dev):
    """Return ``dev.type`` when present; anything without ``.type`` maps to "cuda"."""
    return getattr(dev, 'type', "cuda")
2023-02-17 20:45:29 +00:00
2023-12-04 16:10:00 +00:00
def supports_dtype(device, dtype): #TODO
    """Rough check whether ``device`` can compute in ``dtype``.

    float32 is always supported. CPU devices support nothing else. Other
    devices support float16 and bfloat16.
    """
    if dtype == torch.float32:
        return True
    if is_device_cpu(device):
        return False
    return dtype in (torch.float16, torch.bfloat16)
2023-12-22 19:24:04 +00:00
def device_supports_non_blocking(device):
    """Return False for MPS devices (non-blocking copies are avoided there), True otherwise."""
    #pytorch bug? mps doesn't support non blocking
    return not is_device_mps(device)
2023-09-20 21:52:41 +00:00
def cast_to_device(tensor, device, dtype, copy=False):
    """Move ``tensor`` to ``device`` and convert it to ``dtype``.

    float32/float16 tensors, and bfloat16 tensors going to CUDA (or any target
    on Intel XPU), use two steps: ``.to(device)`` then ``.to(dtype)``. With
    copy=True and a tensor already on ``device``, only the dtype step runs.
    All other tensors use a single ``.to(device, dtype)`` call. Transfers are
    non-blocking unless the device is MPS.
    """
    if tensor.dtype in (torch.float32, torch.float16):
        two_step = True
    elif tensor.dtype == torch.bfloat16:
        two_step = (hasattr(device, 'type') and device.type.startswith("cuda")) or is_intel_xpu()
    else:
        two_step = False

    non_blocking = device_supports_non_blocking(device)

    if not two_step:
        return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
    if not copy:
        return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
    if tensor.device == device:
        return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
    return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
2023-04-05 02:22:02 +00:00
2023-03-12 19:44:16 +00:00
def xformers_enabled():
    """Return XFORMERS_IS_AVAILABLE on the GPU path; False on CPU/MPS, Intel XPU or DirectML."""
    if cpu_state != CPUState.GPU or is_intel_xpu() or directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE
2023-03-12 19:44:16 +00:00
2023-04-05 02:22:02 +00:00
def xformers_enabled_vae():
    """Return True only if xformers is enabled and the installed version is safe for the VAE."""
    return XFORMERS_ENABLED_VAE if xformers_enabled() else False
2023-04-05 02:22:02 +00:00
2023-03-13 16:25:19 +00:00
def pytorch_attention_enabled():
    """Return whether PyTorch scaled-dot-product attention was selected at import time."""
    enabled = ENABLE_PYTORCH_ATTENTION
    return enabled
2023-05-06 23:58:54 +00:00
def pytorch_attention_flash_attention():
    """Return True when PyTorch attention is enabled and running on an Nvidia (CUDA) build."""
    #TODO: more reliable way of checking for flash attention?
    #pytorch flash attention only works on Nvidia
    return bool(ENABLE_PYTORCH_ATTENTION and is_nvidia())
2023-03-03 08:27:33 +00:00
def get_free_memory(dev=None, torch_free_too=False):
    """Return free memory in bytes on ``dev`` (``dev`` defaults to get_torch_device()).

    With torch_free_too=True, returns ``(free_total, free_torch)``. On XPU/CUDA,
    ``free_torch`` is memory reserved by torch's allocator but not active
    (reserved - active). On CUDA, ``free_total`` is mem_get_info's free memory
    plus ``free_torch``; on XPU it is total_memory minus allocated bytes. On
    cpu/mps both values are psutil's available RAM, and on DirectML both are a
    fixed 1 GiB placeholder.
    """
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and dev.type in ('cpu', 'mps'):
        free_total = psutil.virtual_memory().available
        free_torch = free_total
    elif directml_enabled:
        free_total = 1024 * 1024 * 1024 #TODO
        free_torch = free_total
    elif is_intel_xpu():
        stats = torch.xpu.memory_stats(dev)
        active = stats['active_bytes.all.current']
        allocated = stats['allocated_bytes.all.current']
        reserved = stats['reserved_bytes.all.current']
        free_torch = reserved - active
        free_total = torch.xpu.get_device_properties(dev).total_memory - allocated
    else:
        stats = torch.cuda.memory_stats(dev)
        active = stats['active_bytes.all.current']
        reserved = stats['reserved_bytes.all.current']
        free_cuda, _ = torch.cuda.mem_get_info(dev)
        free_torch = reserved - active
        free_total = free_cuda + free_torch

    if not torch_free_too:
        return free_total
    return (free_total, free_torch)
2023-02-08 19:05:31 +00:00
2023-03-03 16:07:10 +00:00
def cpu_mode():
    """Return True when execution is forced onto the CPU."""
    return cpu_state is CPUState.CPU
2023-03-03 16:07:10 +00:00
2023-03-24 12:04:50 +00:00
def mps_mode():
    """Return True when running on Apple's MPS backend."""
    return cpu_state is CPUState.MPS
2023-03-24 12:04:50 +00:00
2024-02-16 02:10:10 +00:00
def is_device_type(device, type):
    """Return True if ``device`` has a ``.type`` attribute equal to ``type``."""
    return bool(hasattr(device, 'type') and device.type == type)


def is_device_cpu(device):
    """Return True for CPU devices."""
    return is_device_type(device, 'cpu')


def is_device_mps(device):
    """Return True for Apple MPS devices."""
    return is_device_type(device, 'mps')


def is_device_cuda(device):
    """Return True for CUDA devices."""
    return is_device_type(device, 'cuda')
2023-07-01 17:22:51 +00:00
2024-02-04 18:23:43 +00:00
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Decide whether fp16 should be used on ``device`` (default: the CUDA device).

    Returns False for CPU devices, FORCE_FP32, DirectML and CPU mode. Returns
    True for FORCE_FP16, MPS, Intel XPU, ROCm and compute capability >= 8.
    Pascal (10-series) cards, or any card when manual_cast is set, use fp16 only
    when not prioritizing performance, or when ``model_params`` fp32 weights
    (4 bytes each) would not fit in 90% of free memory minus
    minimum_inference_memory(). Cards below capability 7 and the listed
    16-series/T-series cards otherwise get False.

    The CUDA properties are now read from ``device`` when one is given, as
    should_use_bf16 does, instead of always from the default "cuda" device.
    """
    if device is not None:
        if is_device_cpu(device):
            return False

    if FORCE_FP16:
        return True

    if device is not None:
        if is_device_mps(device):
            return True

    if FORCE_FP32:
        return False

    if directml_enabled:
        return False

    if mps_mode():
        return True

    if cpu_mode():
        return False

    if is_intel_xpu():
        return True

    if torch.version.hip:
        return True

    props = torch.cuda.get_device_properties("cuda" if device is None else device)
    if props.major >= 8:
        return True

    if props.major < 6:
        return False

    fp16_works = False
    #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
    #when the model doesn't actually fit on the card
    #TODO: actually test if GP106 and others have the same type of behavior
    nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
    for x in nvidia_10_series:
        if x in props.name.lower():
            fp16_works = True

    if fp16_works or manual_cast:
        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True

    if props.major < 7:
        return False

    #FP16 is just broken on these cards
    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
    for x in nvidia_16_series:
        if x in props.name:
            return False

    return True
2024-02-17 13:13:17 +00:00
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Decide whether bf16 should be used on ``device`` (default: the "cuda" device).

    Returns False for CPU/MPS devices, FORCE_FP32, DirectML, CPU mode and MPS
    mode. Returns True for Intel XPU and CUDA capability >= 8. Otherwise, when
    bf16 is supported or manual_cast is set, returns True if not prioritizing
    performance or if ``model_params`` fp32 weights (4 bytes each) would not fit
    in 90% of free memory minus minimum_inference_memory().
    """
    if device is not None:
        if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
            return False
        if is_device_mps(device): #TODO not sure about mps bf16 support
            return False

    if FORCE_FP32 or directml_enabled or cpu_mode() or mps_mode():
        return False

    if is_intel_xpu():
        return True

    props = torch.cuda.get_device_properties(torch.device("cuda") if device is None else device)
    if props.major >= 8:
        return True

    if torch.cuda.is_bf16_supported() or manual_cast:
        free_model_memory = get_free_memory() * 0.9 - minimum_inference_memory()
        if not prioritize_performance or model_params * 4 > free_model_memory:
            return True

    return False
2023-09-04 04:58:18 +00:00
def soft_empty_cache(force=False):
    """Release cached allocator memory on the active backend.

    MPS and Intel XPU always empty their caches. On CUDA builds the cache is
    emptied (plus ipc_collect) only on Nvidia, or when ``force`` is set.
    """
    if cpu_state == CPUState.MPS:
        torch.mps.empty_cache()
        return
    if is_intel_xpu():
        torch.xpu.empty_cache()
        return
    if not torch.cuda.is_available():
        return
    if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
2023-12-23 09:25:06 +00:00
def unload_all_models():
    """Unload every model on the default device by requesting an unreachable 1e30 bytes."""
    device = get_torch_device()
    free_memory(1e30, device)
2023-12-22 19:24:04 +00:00
def resolve_lowvram_weight(weight, model, key): #TODO: remove
    """Legacy no-op kept for callers: returns ``weight`` unchanged; ``model`` and ``key`` are ignored."""
    resolved = weight
    return resolved
2023-03-02 19:42:03 +00:00
#TODO: might be cleaner to put this somewhere else
import threading
class InterruptProcessingException(Exception):
    """Raised by throw_exception_if_processing_interrupted() once an interrupt was requested."""
    pass


# Interrupt flag shared between threads; every access goes through the re-entrant lock.
interrupt_processing_mutex = threading.RLock()
interrupt_processing = False


def interrupt_current_processing(value=True):
    """Set the interrupt flag to ``value`` (pass False to clear it)."""
    global interrupt_processing
    with interrupt_processing_mutex:
        interrupt_processing = value


def processing_interrupted():
    """Return the current interrupt flag."""
    with interrupt_processing_mutex:
        return interrupt_processing


def throw_exception_if_processing_interrupted():
    """If the interrupt flag is set, clear it and raise InterruptProcessingException."""
    global interrupt_processing
    with interrupt_processing_mutex:
        if not interrupt_processing:
            return
        interrupt_processing = False
        raise InterruptProcessingException()