Refactor and improve model_management code related to free memory.

comfyanonymous 2023-06-02 15:05:25 -04:00
parent 499641ebf1
commit 67892b5ac5
2 changed files with 68 additions and 69 deletions

comfy/model_management.py

@@ -1,6 +1,7 @@
 import psutil
 from enum import Enum
 from comfy.cli_args import args
+import torch
 
 class VRAMState(Enum):
     CPU = 0
@@ -32,19 +33,60 @@ if args.directml is not None:
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
-try:
-    import torch
-    if directml_enabled:
-        pass #TODO
-    else:
-        try:
-            import intel_extension_for_pytorch as ipex
-            if torch.xpu.is_available():
-                xpu_available = True
-                total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
-        except:
-            total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
-    total_ram = psutil.virtual_memory().total / (1024 * 1024)
-    if not args.normalvram and not args.cpu:
-        if lowvram_available and total_vram <= 4096:
-            print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
+try:
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        xpu_available = True
+except:
+    pass
+
+def get_torch_device():
+    global xpu_available
+    global directml_enabled
+    if directml_enabled:
+        global directml_device
+        return directml_device
+    if vram_state == VRAMState.MPS:
+        return torch.device("mps")
+    if vram_state == VRAMState.CPU:
+        return torch.device("cpu")
+    else:
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.device(torch.cuda.current_device())
+
+def get_total_memory(dev=None, torch_total_too=False):
+    global xpu_available
+    global directml_enabled
+    if dev is None:
+        dev = get_torch_device()
+
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+        mem_total = psutil.virtual_memory().total
+        mem_total_torch = mem_total
+    else:
+        if directml_enabled:
+            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total_torch = mem_total
+        elif xpu_available:
+            mem_total = torch.xpu.get_device_properties(dev).total_memory
+            mem_total_torch = mem_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_cuda
+
+    if torch_total_too:
+        return (mem_total, mem_total_torch)
+    else:
+        return mem_total
+
+total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
+total_ram = psutil.virtual_memory().total / (1024 * 1024)
+print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+if not args.normalvram and not args.cpu:
+    if lowvram_available and total_vram <= 4096:
+        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
@@ -52,8 +94,6 @@ try:
-        elif total_vram > total_ram * 1.1 and total_vram > 14336:
-            print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-            vram_state = VRAMState.HIGH_VRAM
-except:
-    pass
+    elif total_vram > total_ram * 1.1 and total_vram > 14336:
+        print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
+        vram_state = VRAMState.HIGH_VRAM
 
 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
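The threshold logic above reduces to two comparisons on megabyte counts; restated as a standalone sketch (the function name and string return values are hypothetical, the real code mutates vram_state and set_vram_to in place):

    def pick_vram_mode(total_vram_mb, total_ram_mb, lowvram_available=True):
        # 4 GB of VRAM or less: try lowvram mode.
        if lowvram_available and total_vram_mb <= 4096:
            return "LOW_VRAM"
        # More VRAM than system RAM (with 10% slack) and above 14 GB: highvram mode.
        if total_vram_mb > total_ram_mb * 1.1 and total_vram_mb > 14336:
            return "HIGH_VRAM"
        return "NORMAL_VRAM"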
@@ -128,29 +168,17 @@ if args.cpu:
 print(f"Set vram state to: {vram_state.name}")
 
-def get_torch_device():
-    global xpu_available
-    global directml_enabled
-    if directml_enabled:
-        global directml_device
-        return directml_device
-    if vram_state == VRAMState.MPS:
-        return torch.device("mps")
-    if vram_state == VRAMState.CPU:
-        return torch.device("cpu")
-    else:
-        if xpu_available:
-            return torch.device("xpu")
-        else:
-            return torch.cuda.current_device()
 
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
-        return "{}".format(device.type)
-    else:
-        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
+        if device.type == "cuda":
+            return "{} {}".format(device, torch.cuda.get_device_name(device))
+        else:
+            return "{}".format(device.type)
+    return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
 try:
-    print("Using device:", get_torch_device_name(get_torch_device()))
+    print("Device:", get_torch_device_name(get_torch_device()))
 except:
     print("Could not pick default device.")
@@ -308,33 +336,6 @@ def pytorch_attention_flash_attention():
         return True
     return False
 
-def get_total_memory(dev=None, torch_total_too=False):
-    global xpu_available
-    global directml_enabled
-    if dev is None:
-        dev = get_torch_device()
-
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_total = psutil.virtual_memory().total
-    else:
-        if directml_enabled:
-            mem_total = 1024 * 1024 * 1024 #TODO
-            mem_total_torch = mem_total
-        elif xpu_available:
-            mem_total = torch.xpu.get_device_properties(dev).total_memory
-            mem_total_torch = mem_total
-        else:
-            stats = torch.cuda.memory_stats(dev)
-            mem_reserved = stats['reserved_bytes.all.current']
-            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
-            mem_total_torch = mem_reserved
-            mem_total = mem_total_cuda
-
-    if torch_total_too:
-        return (mem_total, mem_total_torch)
-    else:
-        return mem_total
-
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
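get_free_memory mirrors the get_total_memory signature (optional device, optional torch-side value), so the two compose naturally; a sketch, assuming both return bytes as above:

    import comfy.model_management as mm

    dev = mm.get_torch_device()
    free = mm.get_free_memory(dev)
    total = mm.get_total_memory(dev)
    print("in use: {:0.0f} MB".format((total - free) / (1024 * 1024)))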

server.py

@@ -7,7 +7,6 @@ import execution
 import uuid
 import json
 import glob
-import torch
 from PIL import Image
 from io import BytesIO
@@ -284,9 +283,8 @@ class PromptServer():
         @routes.get("/system_stats")
         async def get_queue(request):
-            device_index = comfy.model_management.get_torch_device()
-            device = torch.device(device_index)
-            device_name = comfy.model_management.get_torch_device_name(device_index)
+            device = comfy.model_management.get_torch_device()
+            device_name = comfy.model_management.get_torch_device_name(device)
             vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
             vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
             system_stats = {
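With the torch import gone, the handler builds everything from comfy.model_management; a sketch of the remaining body (the system_stats dict is truncated above, so the field names here are illustrative only):

    device = comfy.model_management.get_torch_device()
    device_name = comfy.model_management.get_torch_device_name(device)
    vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
    vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
    system_stats = {
        "devices": [{
            "name": device_name,              # illustrative field names
            "vram_total": vram_total,
            "vram_free": vram_free,
            "torch_vram_total": torch_vram_total,
            "torch_vram_free": torch_vram_free,
        }]
    }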