Initial minimal XLA support

radna0 2024-11-17 22:45:50 +00:00
parent d9f90965c8
commit e5182dd427
4 changed files with 65 additions and 2 deletions


@@ -200,6 +200,23 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### TPU/XLA Devices
Users with TPU/XLA devices can install the PyTorch XLA stable build with the following command:
```pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html```
To install the 2.6.0 nightly build, which may bring some performance improvements, run:
```
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
```
To get memory info for TPU devices, install the [tpu-info](https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info) package with the following command:
```pip install tpu-info```
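Since this commit also adds an `--xla` flag to the CLI arguments (see the cli_args change below), ComfyUI should then be launchable on an XLA device with: ```python main.py --xla```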
# Running
```python main.py```


@@ -138,6 +138,8 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--xla", action="store_true", help="To use the XLA devices for everything.")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"


@@ -36,6 +36,7 @@ class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2
    XLA = 3

# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
@@ -84,9 +85,28 @@ try:
except:
    pass

try:
    if args.xla:
        import torch_xla as xla
        import torch_xla.core.xla_model as xm
        cpu_state = CPUState.XLA
except ImportError:
    logging.error("XLA not available, please install pytorch-xla")
    pass

if args.cpu:
    cpu_state = CPUState.CPU

def get_xla_memory_info(dev):
    # xm.get_memory_info(dev) only has bytes_limit and bytes_used
    mem_info = xm.get_memory_info(dev)
    mem_reserved = mem_info["bytes_used"]
    mem_total = mem_info["bytes_limit"]
    return (mem_reserved, mem_total)

def is_intel_xpu():
    global cpu_state
    global xpu_available
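For reference, a minimal standalone sketch of what `get_xla_memory_info` wraps; it assumes torch_xla is installed with an XLA/TPU device attached and, as the comment above notes, that `xm.get_memory_info` exposes only `bytes_used` and `bytes_limit`:
```
import torch_xla.core.xla_model as xm

dev = xm.xla_device()              # first available XLA device (e.g. a TPU core)
info = xm.get_memory_info(dev)     # dict with 'bytes_used' and 'bytes_limit'
used, total = info["bytes_used"], info["bytes_limit"]
print(f"XLA device memory: {used / 2**30:.2f} GiB used of {total / 2**30:.2f} GiB")
```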
@@ -105,6 +125,8 @@ def get_torch_device():
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    if cpu_state == CPUState.XLA:
        return xla.device()
    else:
        if is_intel_xpu():
            return torch.device("xpu", torch.xpu.current_device())
@@ -128,6 +150,8 @@ def get_total_memory(dev=None, torch_total_too=False):
            mem_reserved = stats['reserved_bytes.all.current']
            mem_total_torch = mem_reserved
            mem_total = torch.xpu.get_device_properties(dev).total_memory
        elif cpu_state == CPUState.XLA:
            mem_total_torch, mem_total = get_xla_memory_info(dev)
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
@@ -241,7 +265,7 @@ if lowvram_available:
        vram_state = set_vram_to

if cpu_state != CPUState.GPU and cpu_state != CPUState.XLA:
    vram_state = VRAMState.DISABLED

if cpu_state == CPUState.MPS:
@@ -924,6 +948,10 @@ def get_free_memory(dev=None, torch_free_too=False):
            mem_free_torch = mem_reserved - mem_active
            mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
            mem_free_total = mem_free_xpu + mem_free_torch
        elif cpu_state == CPUState.XLA:
            mem_reserved, mem_total = get_xla_memory_info(dev)
            mem_free_total = mem_total - mem_reserved
            mem_free_torch = mem_free_total
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
@@ -937,6 +965,10 @@ def get_free_memory(dev=None, torch_free_too=False):
    else:
        return mem_free_total

def xla_mode():
    global cpu_state
    return cpu_state == CPUState.XLA

def cpu_mode():
    global cpu_state
    return cpu_state == CPUState.CPU
@@ -985,6 +1017,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
    if cpu_mode():
        return False

    if xla_mode():
        return True

    if is_intel_xpu():
        return True
@@ -1044,6 +1079,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
    if cpu_mode():
        return False

    if xla_mode():
        return True

    if is_intel_xpu():
        return True
@@ -1062,6 +1100,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
    return False

def supports_fp8_compute(device=None):
    if xla_mode():
        return False

    if not is_nvidia():
        return False
@@ -1101,7 +1142,7 @@ def resolve_lowvram_weight(weight, model, key): #TODO: remove
    print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
    return weight

# TODO: might be cleaner to put this somewhere else
import threading

class InterruptProcessingException(Exception):


@@ -102,6 +102,9 @@ def prepare_callback(model, steps, x0_output_dict=None):
        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)

        if args.xla:
            import torch_xla as xla
            xla.sync()

        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    return callback
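The `xla.sync()` call above matters because XLA tensors execute lazily: operations are only recorded into a graph and run at an explicit sync point, or when data is pulled back to the host. A rough standalone sketch of that behavior (not part of this commit), using the long-standing `xm.mark_step()` API that `torch_xla.sync()` is presumed to wrap in recent releases:
```
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
x = torch.randn(4, 4, device=dev)
y = x @ x              # only recorded in the lazy graph, nothing executes yet
xm.mark_step()         # compile and run the pending graph on the XLA device
print(y.cpu())         # pulling the result back to the CPU also forces execution
```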