Refactor and improve model_management code related to free memory.
This commit is contained in:
parent
499641ebf1
commit
67892b5ac5
|
@ -1,6 +1,7 @@
|
||||||
import psutil
|
import psutil
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import torch
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
CPU = 0
|
CPU = 0
|
||||||
|
@ -33,28 +34,67 @@ if args.directml is not None:
|
||||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import intel_extension_for_pytorch as ipex
|
||||||
if directml_enabled:
|
if torch.xpu.is_available():
|
||||||
pass #TODO
|
xpu_available = True
|
||||||
else:
|
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
if torch.xpu.is_available():
|
|
||||||
xpu_available = True
|
|
||||||
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
|
|
||||||
except:
|
|
||||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
|
||||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
|
||||||
if not args.normalvram and not args.cpu:
|
|
||||||
if lowvram_available and total_vram <= 4096:
|
|
||||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
|
||||||
set_vram_to = VRAMState.LOW_VRAM
|
|
||||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
|
||||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
|
||||||
vram_state = VRAMState.HIGH_VRAM
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_torch_device():
|
||||||
|
global xpu_available
|
||||||
|
global directml_enabled
|
||||||
|
if directml_enabled:
|
||||||
|
global directml_device
|
||||||
|
return directml_device
|
||||||
|
if vram_state == VRAMState.MPS:
|
||||||
|
return torch.device("mps")
|
||||||
|
if vram_state == VRAMState.CPU:
|
||||||
|
return torch.device("cpu")
|
||||||
|
else:
|
||||||
|
if xpu_available:
|
||||||
|
return torch.device("xpu")
|
||||||
|
else:
|
||||||
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
|
global xpu_available
|
||||||
|
global directml_enabled
|
||||||
|
if dev is None:
|
||||||
|
dev = get_torch_device()
|
||||||
|
|
||||||
|
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||||
|
mem_total = psutil.virtual_memory().total
|
||||||
|
mem_total_torch = mem_total
|
||||||
|
else:
|
||||||
|
if directml_enabled:
|
||||||
|
mem_total = 1024 * 1024 * 1024 #TODO
|
||||||
|
mem_total_torch = mem_total
|
||||||
|
elif xpu_available:
|
||||||
|
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||||
|
mem_total_torch = mem_total
|
||||||
|
else:
|
||||||
|
stats = torch.cuda.memory_stats(dev)
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
||||||
|
mem_total_torch = mem_reserved
|
||||||
|
mem_total = mem_total_cuda
|
||||||
|
|
||||||
|
if torch_total_too:
|
||||||
|
return (mem_total, mem_total_torch)
|
||||||
|
else:
|
||||||
|
return mem_total
|
||||||
|
|
||||||
|
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||||
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||||
|
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
|
if not args.normalvram and not args.cpu:
|
||||||
|
if lowvram_available and total_vram <= 4096:
|
||||||
|
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||||
|
set_vram_to = VRAMState.LOW_VRAM
|
||||||
|
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||||
|
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||||
|
vram_state = VRAMState.HIGH_VRAM
|
||||||
|
|
||||||
try:
|
try:
|
||||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||||
except:
|
except:
|
||||||
|
@ -128,29 +168,17 @@ if args.cpu:
|
||||||
|
|
||||||
print(f"Set vram state to: {vram_state.name}")
|
print(f"Set vram state to: {vram_state.name}")
|
||||||
|
|
||||||
def get_torch_device():
|
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
|
||||||
if directml_enabled:
|
|
||||||
global directml_device
|
|
||||||
return directml_device
|
|
||||||
if vram_state == VRAMState.MPS:
|
|
||||||
return torch.device("mps")
|
|
||||||
if vram_state == VRAMState.CPU:
|
|
||||||
return torch.device("cpu")
|
|
||||||
else:
|
|
||||||
if xpu_available:
|
|
||||||
return torch.device("xpu")
|
|
||||||
else:
|
|
||||||
return torch.cuda.current_device()
|
|
||||||
|
|
||||||
def get_torch_device_name(device):
|
def get_torch_device_name(device):
|
||||||
if hasattr(device, 'type'):
|
if hasattr(device, 'type'):
|
||||||
return "{}".format(device.type)
|
if device.type == "cuda":
|
||||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
return "{} {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
else:
|
||||||
|
return "{}".format(device.type)
|
||||||
|
else:
|
||||||
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print("Using device:", get_torch_device_name(get_torch_device()))
|
print("Device:", get_torch_device_name(get_torch_device()))
|
||||||
except:
|
except:
|
||||||
print("Could not pick default device.")
|
print("Could not pick default device.")
|
||||||
|
|
||||||
|
@ -308,33 +336,6 @@ def pytorch_attention_flash_attention():
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
|
||||||
if dev is None:
|
|
||||||
dev = get_torch_device()
|
|
||||||
|
|
||||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
|
||||||
mem_total = psutil.virtual_memory().total
|
|
||||||
else:
|
|
||||||
if directml_enabled:
|
|
||||||
mem_total = 1024 * 1024 * 1024 #TODO
|
|
||||||
mem_total_torch = mem_total
|
|
||||||
elif xpu_available:
|
|
||||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
|
||||||
mem_total_torch = mem_total
|
|
||||||
else:
|
|
||||||
stats = torch.cuda.memory_stats(dev)
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
|
||||||
mem_total_torch = mem_reserved
|
|
||||||
mem_total = mem_total_cuda
|
|
||||||
|
|
||||||
if torch_total_too:
|
|
||||||
return (mem_total, mem_total_torch)
|
|
||||||
else:
|
|
||||||
return mem_total
|
|
||||||
|
|
||||||
def get_free_memory(dev=None, torch_free_too=False):
|
def get_free_memory(dev=None, torch_free_too=False):
|
||||||
global xpu_available
|
global xpu_available
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
|
|
|
@ -7,7 +7,6 @@ import execution
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import glob
|
import glob
|
||||||
import torch
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
@ -284,9 +283,8 @@ class PromptServer():
|
||||||
|
|
||||||
@routes.get("/system_stats")
|
@routes.get("/system_stats")
|
||||||
async def get_queue(request):
|
async def get_queue(request):
|
||||||
device_index = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device = torch.device(device_index)
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
device_name = comfy.model_management.get_torch_device_name(device_index)
|
|
||||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||||
system_stats = {
|
system_stats = {
|
||||||
|
|
Loading…
Reference in New Issue