ComfyUI/comfy/model_management.py

328 lines
8.8 KiB
Python
Raw Normal View History

# Execution modes stored in `vram_state`. The numeric values are also used as
# indices into the list of mode names printed at start-up, so they must stay
# the consecutive integers 0..6 in exactly this order.
CPU, NO_VRAM, LOW_VRAM, NORMAL_VRAM, HIGH_VRAM, MPS, XPU = range(7)

# Becomes True once the `accelerate` package has been imported successfully.
accelerate_enabled = False
# Currently active mode; refined below from CLI flags and detected hardware.
vram_state = NORMAL_VRAM

# Total memory of the current CUDA device in MiB (stays 0 if detection fails).
total_vram = 0
# GPU memory budget in MiB given to accelerate in LOW_VRAM mode; -1 until it
# has been computed.
total_vram_available_mb = -1
import sys
import psutil
2023-03-27 04:48:09 +00:00
# "--cpu" on the command line forces CPU execution and switches off the
# automatic VRAM heuristics below.
forced_cpu = "--cpu" in sys.argv

# Low/no-VRAM mode requested by the heuristics or by CLI flags. It is only
# copied into `vram_state` later, once `accelerate` is known to import.
set_vram_to = NORMAL_VRAM

try:
    import torch

    # mem_get_info() returns (free, total) in bytes; keep the total, in MiB.
    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
    total_ram = psutil.virtual_memory().total / (1024 * 1024)
    forced_normal_vram = "--normalvram" in sys.argv

    if not forced_normal_vram and not forced_cpu:
        if total_vram <= 4096:
            print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
            set_vram_to = LOW_VRAM
        elif total_vram > total_ram * 1.1 and total_vram > 14336:
            print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
            vram_state = HIGH_VRAM
except Exception:
    # Best effort: if torch is missing or no CUDA device can be queried, the
    # defaults above stay in place. Only `Exception` is caught (the original
    # used a bare `except:`) so Ctrl-C and SystemExit are no longer swallowed.
    pass
# Exception class exposed for CUDA out-of-memory errors. Falls back to the
# generic `Exception` when `torch.cuda.OutOfMemoryError` cannot be resolved
# (torch failed to import above, or this torch build lacks the attribute).
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except Exception:
    # Catch `Exception` rather than everything, so KeyboardInterrupt and
    # SystemExit still propagate.
    OOM_EXCEPTION = Exception
# Decide whether xformers can be used: it is off when "--disable-xformers" is
# passed, otherwise it is on only if both `xformers` and `xformers.ops` import.
# (The misspelled name is kept because other code refers to it.)
if "--disable-xformers" in sys.argv:
    XFORMERS_IS_AVAILBLE = False
else:
    try:
        import xformers
        import xformers.ops
    except Exception:
        # Any import-time failure means xformers is unusable. `Exception`
        # (not a bare except) keeps Ctrl-C / SystemExit propagating.
        XFORMERS_IS_AVAILBLE = False
    else:
        XFORMERS_IS_AVAILBLE = True
2023-03-13 16:25:19 +00:00
# Opt-in switch ("--use-pytorch-cross-attention") for PyTorch's own
# scaled-dot-product attention.
ENABLE_PYTORCH_ATTENTION = "--use-pytorch-cross-attention" in sys.argv
if ENABLE_PYTORCH_ATTENTION:
    # Enable the math, flash and memory-efficient SDP backends.
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    # PyTorch attention replaces xformers, so report xformers as unavailable.
    XFORMERS_IS_AVAILBLE = False
# Explicit CLI overrides. "--novram" wins over "--lowvram" when both are
# given; "--highvram" sets the active state directly.
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM
elif "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--highvram" in sys.argv:
    vram_state = HIGH_VRAM

# Low/no-VRAM modes depend on the `accelerate` package; the requested mode is
# only activated when that import succeeds.
if set_vram_to in (LOW_VRAM, NO_VRAM):
    try:
        import accelerate
        accelerate_enabled = True
        vram_state = set_vram_to
    except Exception:
        import traceback
        print(traceback.format_exc())
        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")

    # GPU budget for accelerate: half of what remains after setting aside
    # 1024 MiB, but never less than 256 MiB.
    total_vram_available_mb = int(max(256, (total_vram - 1024) // 2))
# Apple MPS detection: when the backend reports itself available it replaces
# whatever mode was chosen so far.
try:
    if torch.backends.mps.is_available():
        vram_state = MPS
except Exception:
    # torch missing or too old to have the MPS backend: keep the current
    # mode. `Exception` (not a bare except) lets Ctrl-C / SystemExit through.
    pass

# Intel XPU detection. The extension is imported before `torch.xpu` is
# queried (presumably the import is what makes `torch.xpu` usable; confirm).
try:
    import intel_extension_for_pytorch as ipex
    if torch.xpu.is_available():
        vram_state = XPU
except Exception:
    # Extension not installed or XPU not usable: keep the current mode.
    pass
2023-03-27 04:48:09 +00:00
# "--cpu" has the final say over every detection step above.
if forced_cpu:
    vram_state = CPU

# The names are listed in the same order as the numeric mode constants.
print(
    "Set vram state to:",
    ("CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU")[vram_state],
)
# Model patcher object currently loaded for inference, or None.
current_loaded_model = None
# Controlnets that load_controlnet_gpu() has moved to the compute device.
current_gpu_controlnets = list()

# True while the loaded model is dispatched through accelerate hooks.
model_accelerated = False
def unload_model():
    """Release the currently loaded model and any tracked controlnets.

    Removes accelerate hooks if they were installed, moves the model to the
    CPU (except in HIGH_VRAM mode), undoes its patches and clears
    `current_loaded_model`. Outside HIGH_VRAM mode the tracked controlnets
    are moved to the CPU as well and the tracking list is reset.
    """
    global current_loaded_model, model_accelerated, current_gpu_controlnets

    #never unload models from GPU on high vram
    keep_on_gpu = vram_state == HIGH_VRAM

    loaded = current_loaded_model
    if loaded is not None:
        if model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(loaded.model)
            model_accelerated = False

        if not keep_on_gpu:
            loaded.model.cpu()
        loaded.unpatch_model()
        current_loaded_model = None

    if keep_on_gpu:
        return

    if current_gpu_controlnets:
        for controlnet in current_gpu_controlnets:
            controlnet.cpu()
        current_gpu_controlnets = []
def load_model_gpu(model):
    """Make `model` the active model and place it according to `vram_state`.

    Any previously loaded model is unloaded first. `model` must provide
    `patch_model()` (returning the underlying module) and `unpatch_model()`.

    Returns the newly loaded model, or None when `model` was already the
    active one. An exception raised by `patch_model()` is re-raised after
    `unpatch_model()` has been called.
    """
    global current_loaded_model
    global model_accelerated

    if model is current_loaded_model:
        return

    unload_model()

    try:
        real_model = model.patch_model()
    except Exception:
        # Undo whatever patching was applied before propagating the error.
        model.unpatch_model()
        raise

    current_loaded_model = model

    if vram_state in (NORMAL_VRAM, HIGH_VRAM):
        model_accelerated = False
        real_model.cuda()
    elif vram_state == MPS:
        real_model.to(torch.device("mps"))
    elif vram_state == XPU:
        real_model.to("xpu")
    elif vram_state != CPU:
        # NO_VRAM / LOW_VRAM: let accelerate split the model between GPU 0
        # and the CPU under a GPU memory cap. In CPU mode nothing is moved.
        if vram_state == NO_VRAM:
            gpu_budget = "256MiB"
        elif vram_state == LOW_VRAM:
            gpu_budget = "{}MiB".format(total_vram_available_mb)

        device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: gpu_budget, "cpu": "16GiB"})
        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
        model_accelerated = True

    return current_loaded_model
2023-02-16 15:38:08 +00:00
def load_controlnet_gpu(models):
    """Move the controlnets in `models` to the compute device and track them.

    Does nothing in CPU mode. Also does nothing in LOW_VRAM / NO_VRAM mode,
    where controlnets are loaded right before running and unloaded right
    after instead. Previously tracked controlnets that are not part of
    `models` are moved back to the CPU.
    """
    global current_gpu_controlnets

    if vram_state in (CPU, LOW_VRAM, NO_VRAM):
        return

    for stale in current_gpu_controlnets:
        if stale not in models:
            stale.cpu()

    target = get_torch_device()
    current_gpu_controlnets = []
    # extend() appends one item at a time, like the explicit append loop.
    current_gpu_controlnets.extend(net.to(target) for net in models)
2023-02-16 15:38:08 +00:00
2023-02-17 20:45:29 +00:00
def load_if_low_vram(model):
    """Return `model.cuda()` in LOW_VRAM / NO_VRAM mode, else `model` unchanged."""
    return model.cuda() if vram_state in (LOW_VRAM, NO_VRAM) else model
def unload_if_low_vram(model):
    """Return `model.cpu()` in LOW_VRAM / NO_VRAM mode, else `model` unchanged."""
    return model.cpu() if vram_state in (LOW_VRAM, NO_VRAM) else model
def get_torch_device():
    """Return the device models should run on for the current `vram_state`.

    MPS, XPU and CPU modes yield a `torch.device`. Every other mode yields
    the value of `torch.cuda.current_device()`, which is not wrapped in a
    `torch.device` here.
    """
    device_name = {MPS: "mps", XPU: "xpu", CPU: "cpu"}.get(vram_state)
    if device_name is None:
        return torch.cuda.current_device()
    return torch.device(device_name)
def get_autocast_device(dev):
    """Return `dev.type` if `dev` has that attribute, otherwise "cuda"."""
    return getattr(dev, 'type', "cuda")
2023-02-17 20:45:29 +00:00
def xformers_enabled():
    """Return False in CPU mode, otherwise the value of XFORMERS_IS_AVAILBLE."""
    return vram_state != CPU and XFORMERS_IS_AVAILBLE
def xformers_enabled_vae():
    """Report whether xformers attention may be used for the VAE.

    Returns False when xformers is disabled in general or when the installed
    xformers version is exactly "0.0.18"; otherwise returns the result of
    `xformers_enabled()`.
    """
    enabled = xformers_enabled()
    if not enabled:
        return False
    try:
        #0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
        if xformers.version.__version__ == "0.0.18":
            return False
    except Exception:
        # The version could not be read: best effort, assume it is usable.
        # `Exception` instead of a bare except keeps Ctrl-C / SystemExit
        # from being swallowed here.
        pass
    return enabled
2023-03-13 16:25:19 +00:00
def pytorch_attention_enabled():
    """Return the module-level ENABLE_PYTORCH_ATTENTION flag."""
    return ENABLE_PYTORCH_ATTENTION
def get_free_memory(dev=None, torch_free_too=False):
    """Return the free memory, in bytes, on `dev`.

    dev: device to query; defaults to `get_torch_device()`.
    torch_free_too: when true, return `(total_free, torch_free)` instead of
        only `total_free`.

    For 'cpu' and 'mps' devices both values are the available system memory
    reported by psutil. For 'xpu' both are total device memory minus memory
    allocated through torch. Anything else is queried through torch.cuda:
    `torch_free` is reserved-minus-active bytes and `total_free` adds the
    free bytes reported by `torch.cuda.mem_get_info`.
    """
    if dev is None:
        dev = get_torch_device()

    # Objects without a `.type` attribute fall through to the CUDA branch.
    kind = getattr(dev, 'type', None)

    if kind in ('cpu', 'mps'):
        free_total = psutil.virtual_memory().available
        free_torch = free_total
    elif kind == 'xpu':
        free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
        free_torch = free_total
    else:
        stats = torch.cuda.memory_stats(dev)
        active_bytes = stats['active_bytes.all.current']
        reserved_bytes = stats['reserved_bytes.all.current']
        free_cuda, _ = torch.cuda.mem_get_info(dev)
        # Reserved-but-inactive memory is counted as free.
        free_torch = reserved_bytes - active_bytes
        free_total = free_cuda + free_torch

    return (free_total, free_torch) if torch_free_too else free_total
def maximum_batch_area():
    """Return a non-negative integer batch-area limit from free memory.

    Always 0 in NO_VRAM mode. Otherwise the free memory in MiB, minus
    1024 MiB, is scaled by 0.9 / 0.6 and clamped at zero.
    """
    if vram_state == NO_VRAM:
        return 0

    free_mb = get_free_memory() / (1024 * 1024)
    usable = ((free_mb - 1024) * 0.9) / (0.6)
    return int(usable) if usable > 0 else 0
def cpu_mode():
    """Return True when `vram_state` is CPU."""
    return vram_state == CPU
2023-03-24 12:04:50 +00:00
def mps_mode():
    """Return True when `vram_state` is MPS."""
    return vram_state == MPS
def xpu_mode():
    """Return True when `vram_state` is XPU."""
    return vram_state == XPU
def should_use_fp16():
    """Decide whether models should run in half precision.

    Returns False in CPU, MPS and XPU modes. On CUDA it returns True when
    bf16 is supported; otherwise False for devices with compute-capability
    major version below 7 or whose name contains one of the listed model
    numbers, and True for everything else.
    """
    if cpu_mode() or mps_mode() or xpu_mode():
        return False #TODO ?

    if torch.cuda.is_bf16_supported():
        return True

    props = torch.cuda.get_device_properties("cuda")
    if props.major < 7:
        return False

    #FP32 is faster on those cards?
    nvidia_16_series = ("1660", "1650", "1630", "T500", "T550", "T600")
    return not any(model_number in props.name for model_number in nvidia_16_series)
#TODO: might be cleaner to put this somewhere else
import threading


class InterruptProcessingException(Exception):
    """Raised by throw_exception_if_processing_interrupted() when an interrupt is pending."""


# Re-entrant lock held around every read and write of `interrupt_processing`.
interrupt_processing_mutex = threading.RLock()

# Pending-interrupt flag, set through interrupt_current_processing().
interrupt_processing = False
def interrupt_current_processing(value=True):
    """Set the pending-interrupt flag to `value` (True by default) under the lock."""
    global interrupt_processing
    with interrupt_processing_mutex:
        interrupt_processing = value
def processing_interrupted():
    """Return the current value of the pending-interrupt flag, read under the lock."""
    with interrupt_processing_mutex:
        pending = interrupt_processing
    return pending
def throw_exception_if_processing_interrupted():
    """Raise InterruptProcessingException if an interrupt is pending.

    The flag is cleared before raising, so each interrupt request produces
    one exception. Does nothing when no interrupt is pending.
    """
    global interrupt_processing
    with interrupt_processing_mutex:
        if not interrupt_processing:
            return
        interrupt_processing = False
        raise InterruptProcessingException()