diff --git a/comfy/model_management.py b/comfy/model_management.py index 0ea0c71e..9c3147d7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,7 @@ import psutil from enum import Enum from comfy.cli_args import args +import torch class VRAMState(Enum): CPU = 0 @@ -33,28 +34,67 @@ if args.directml is not None: lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. try: - import torch - if directml_enabled: - pass #TODO - else: - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) - total_ram = psutil.virtual_memory().total / (1024 * 1024) - if not args.normalvram and not args.cpu: - if lowvram_available and total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = VRAMState.LOW_VRAM - elif total_vram > total_ram * 1.1 and total_vram > 14336: - print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = VRAMState.HIGH_VRAM + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True except: pass +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.device(torch.cuda.current_device()) + +def get_total_memory(dev=None, torch_total_too=False): + global xpu_available + global directml_enabled + if dev is None: + dev = get_torch_device() + + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_total = psutil.virtual_memory().total + mem_total_torch = mem_total + else: + if directml_enabled: + mem_total = 1024 * 1024 * 1024 #TODO + mem_total_torch = mem_total + elif xpu_available: + mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total + else: + stats = torch.cuda.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_cuda = torch.cuda.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_cuda + + if torch_total_too: + return (mem_total, mem_total_torch) + else: + return mem_total + +total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) +total_ram = psutil.virtual_memory().total / (1024 * 1024) +print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) +if not args.normalvram and not args.cpu: + if lowvram_available and total_vram <= 4096: + print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") + set_vram_to = VRAMState.LOW_VRAM + elif total_vram > total_ram * 1.1 and total_vram > 14336: + print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") + vram_state = VRAMState.HIGH_VRAM + try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError except: @@ -128,29 +168,17 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_torch_device_name(device): if hasattr(device, 'type'): - return "{}".format(device.type) - return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + if device.type == "cuda": + return "{} {}".format(device, torch.cuda.get_device_name(device)) + else: + return "{}".format(device.type) + else: + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - print("Using device:", get_torch_device_name(get_torch_device())) + print("Device:", get_torch_device_name(get_torch_device())) except: print("Could not pick default device.") @@ -308,33 +336,6 @@ def pytorch_attention_flash_attention(): return True return False -def get_total_memory(dev=None, torch_total_too=False): - global xpu_available - global directml_enabled - if dev is None: - dev = get_torch_device() - - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - mem_total = psutil.virtual_memory().total - else: - if directml_enabled: - mem_total = 1024 * 1024 * 1024 #TODO - mem_total_torch = mem_total - elif xpu_available: - mem_total = torch.xpu.get_device_properties(dev).total_memory - mem_total_torch = mem_total - else: - stats = torch.cuda.memory_stats(dev) - mem_reserved = stats['reserved_bytes.all.current'] - _, mem_total_cuda = torch.cuda.mem_get_info(dev) - mem_total_torch = mem_reserved - mem_total = mem_total_cuda - - if torch_total_too: - return (mem_total, mem_total_torch) - else: - return mem_total - def get_free_memory(dev=None, torch_free_too=False): global xpu_available global directml_enabled diff --git a/server.py b/server.py index acbc88f6..5be822a6 100644 --- a/server.py +++ b/server.py @@ -7,7 +7,6 @@ import execution import uuid import json import glob -import torch from PIL import Image from io import BytesIO @@ -284,9 +283,8 @@ class PromptServer(): @routes.get("/system_stats") async def get_queue(request): - device_index = comfy.model_management.get_torch_device() - device = torch.device(device_index) - device_name = comfy.model_management.get_torch_device_name(device_index) + device = comfy.model_management.get_torch_device() + device_name = comfy.model_management.get_torch_device_name(device) vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) system_stats = {