diff --git a/comfy/model_management.py b/comfy/model_management.py
index e9af7f3a..3b7b1dbf 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -308,6 +308,46 @@ def pytorch_attention_flash_attention():
             return True
     return False
 
+def get_total_memory(dev=None, torch_total_too=False):
+    """Return the total memory, in bytes, of a compute device.
+
+    dev: device to query; defaults to get_torch_device().
+    torch_total_too: when True, return (mem_total, mem_total_torch)
+    instead of just mem_total.
+
+    mem_total_torch is the memory currently reserved by torch's CUDA
+    allocator on CUDA devices; on every other backend it equals mem_total.
+    """
+    global xpu_available
+    global directml_enabled
+    if dev is None:
+        dev = get_torch_device()
+
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+        mem_total = psutil.virtual_memory().total
+        # Must be set here too, or torch_total_too=True raises UnboundLocalError.
+        mem_total_torch = mem_total
+    else:
+        if directml_enabled:
+            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total_torch = mem_total
+        elif xpu_available:
+            mem_total = torch.xpu.get_device_properties(dev).total_memory
+            mem_total_torch = mem_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            # mem_get_info already reports the whole device; adding the
+            # reserved bytes on top would count them twice.
+            mem_total = mem_total_cuda
+
+    if torch_total_too:
+        return (mem_total, mem_total_torch)
+    else:
+        return mem_total
+
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
diff --git a/server.py b/server.py
index 0b64df14..acbc88f6 100644
--- a/server.py
+++ b/server.py
@@ -7,6 +7,7 @@ import execution
 import uuid
 import json
 import glob
+import torch
 from PIL import Image
 from io import BytesIO
 
@@ -23,6 +24,7 @@ except ImportError:
 import mimetypes
 from comfy.cli_args import args
 import comfy.utils
+import comfy.model_management
 
 @web.middleware
 async def cache_control(request: web.Request, handler):
@@ -280,6 +282,29 @@ class PromptServer():
                 return web.Response(status=404)
             return web.json_response(dt["__metadata__"])
 
+        @routes.get("/system_stats")
+        async def get_system_stats(request):
+            """Return a JSON description of the active torch device and its memory."""
+            selected_device = comfy.model_management.get_torch_device()
+            device = torch.device(selected_device)
+            device_name = comfy.model_management.get_torch_device_name(selected_device)
+            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
+            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
+            system_stats = {
+                "devices": [
+                    {
+                        "name": device_name,
+                        "type": device.type,
+                        "index": device.index,
+                        "vram_total": vram_total,
+                        "vram_free": vram_free,
+                        "torch_vram_total": torch_vram_total,
+                        "torch_vram_free": torch_vram_free,
+                    }
+                ]
+            }
+            return web.json_response(system_stats)
+
         @routes.get("/prompt")
         async def get_prompt(request):
             return web.json_response(self.get_queue_info())