System stats endpoint

space-nuko 2023-06-01 23:26:23 -05:00
parent 1bbd3f7fe1
commit b5dd15c67a
2 changed files with 51 additions and 0 deletions

comfy/model_management.py

@@ -308,6 +308,33 @@ def pytorch_attention_flash_attention():
        return True
    return False
def get_total_memory(dev=None, torch_total_too=False):
    global xpu_available
    global directml_enabled
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
        mem_total = psutil.virtual_memory().total
    else:
        if directml_enabled:
            mem_total = 1024 * 1024 * 1024 #TODO
            mem_total_torch = mem_total
        elif xpu_available:
            mem_total = torch.xpu.get_device_properties(dev).total_memory
            mem_total_torch = mem_total
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_cuda + mem_total_torch

    if torch_total_too:
        return (mem_total, mem_total_torch)
    else:
        return mem_total
def get_free_memory(dev=None, torch_free_too=False):
    global xpu_available
    global directml_enabled

server.py

@@ -7,6 +7,7 @@ import execution
import uuid
import json
import glob
import torch
from PIL import Image
from io import BytesIO
@@ -23,6 +24,7 @@ except ImportError:
import mimetypes
from comfy.cli_args import args
import comfy.utils
import comfy.model_management

@web.middleware
async def cache_control(request: web.Request, handler):
@@ -280,6 +282,28 @@ class PromptServer():
                return web.Response(status=404)
            return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
device_index = comfy.model_management.get_torch_device()
device = torch.device(device_index)
device_name = comfy.model_management.get_torch_device_name(device_index)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"devices": [
{
"name": device_name,
"type": device.type,
"index": device.index,
"vram_total": vram_total,
"vram_free": vram_free,
"torch_vram_total": torch_vram_total,
"torch_vram_free": torch_vram_free,
}
]
}
return web.json_response(system_stats)
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())
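
Once the server is running, the new endpoint can be exercised with a plain GET request. A minimal sketch, assuming ComfyUI's default listen address of 127.0.0.1:8188 (adjust if the server was started with --listen/--port):

    import json
    from urllib.request import urlopen

    # Fetch the stats payload added by this commit.
    with urlopen("http://127.0.0.1:8188/system_stats") as resp:
        stats = json.load(resp)

    # Each entry mirrors the "devices" list built in the handler above.
    for dev in stats["devices"]:
        print(f'{dev["name"]} (type={dev["type"]}, index={dev["index"]})')
        print(f'  vram_total={dev["vram_total"]} vram_free={dev["vram_free"]}')
        print(f'  torch_vram_total={dev["torch_vram_total"]} torch_vram_free={dev["torch_vram_free"]}')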