From e5182dd427cd41c58dabc5cd683e586238941b5a Mon Sep 17 00:00:00 2001
From: radna0
Date: Sun, 17 Nov 2024 22:45:50 +0000
Subject: [PATCH] initial minimal xla support

---
 README.md                 | 17 +++++++++++++++
 comfy/cli_args.py         |  2 ++
 comfy/model_management.py | 45 +++++++++++++++++++++++++++++++++++++--
 latent_preview.py         |  3 +++
 4 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 7476b578..8f336daf 100644
--- a/README.md
+++ b/README.md
@@ -200,6 +200,23 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
 
 ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
 
+#### TPU/XLA Devices
+Users with TPU/XLA devices can install the PyTorch XLA stable build with the following command:
+
+```pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html```
+
+This is the command to install the 2.6.0 nightly build, which might have some performance improvements:
+
+```
+pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
+```
+
+
+To get memory info for TPU devices, install the [tpu-info](https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info) package with the following command:
+
+```pip install tpu-info```
+
 # Running
 
 ```python main.py```
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 20b9f474..3c66e75b 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -138,6 +138,8 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
 
 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
 
+parser.add_argument("--xla", action="store_true", help="To use the XLA devices for everything.")
+
 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index fd493aff..d06be1b5 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -36,6 +36,7 @@ class CPUState(Enum):
     GPU = 0
     CPU = 1
     MPS = 2
+    XLA = 3
 
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
@@ -84,9 +85,28 @@ try:
 except:
     pass
 
+try:
+    if args.xla:
+        import torch_xla as xla
+        import torch_xla.core.xla_model as xm
+        cpu_state = CPUState.XLA
+
+except ImportError:
+    logging.error("XLA not available, please install pytorch-xla")
+    pass
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
+
+def get_xla_memory_info(dev):
+    # xm.get_memory_info(dev) only has bytes_limit and bytes_used
+    mem_info = xm.get_memory_info(dev)
+    mem_reserved = mem_info["bytes_used"]
+    mem_total = mem_info["bytes_limit"]
+    return (mem_reserved, mem_total)
+
+
 def is_intel_xpu():
     global cpu_state
     global xpu_available
@@ -105,6 +125,8 @@ def get_torch_device():
         return torch.device("mps")
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
+    if cpu_state == CPUState.XLA:
+        return xla.device()
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
@@ -128,6 +150,8 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total_torch = mem_reserved
             mem_total = torch.xpu.get_device_properties(dev).total_memory
+        elif cpu_state == CPUState.XLA:
+            mem_total_torch, mem_total = get_xla_memory_info(dev)
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -241,7 +265,7 @@ if lowvram_available:
     if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
         vram_state = set_vram_to
 
-if cpu_state != CPUState.GPU:
+if cpu_state != CPUState.GPU and cpu_state != CPUState.XLA:
     vram_state = VRAMState.DISABLED
 
 if cpu_state == CPUState.MPS:
@@ -924,6 +948,10 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_torch = mem_reserved - mem_active
             mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
             mem_free_total = mem_free_xpu + mem_free_torch
+        elif cpu_state == CPUState.XLA:
+            mem_reserved, mem_total = get_xla_memory_info(dev)
+            mem_free_total = mem_total - mem_reserved
+            mem_free_torch = mem_free_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -937,6 +965,10 @@ def get_free_memory(dev=None, torch_free_too=False):
     else:
         return mem_free_total
 
+def xla_mode():
+    global cpu_state
+    return cpu_state == CPUState.XLA
+
 def cpu_mode():
     global cpu_state
     return cpu_state == CPUState.CPU
@@ -985,6 +1017,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
     if cpu_mode():
         return False
+
+    if xla_mode():
+        return True
 
     if is_intel_xpu():
         return True
@@ -1044,6 +1079,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
 
     if cpu_mode():
         return False
+
+    if xla_mode():
+        return True
 
     if is_intel_xpu():
         return True
@@ -1062,6 +1100,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     return False
 
 def supports_fp8_compute(device=None):
+    if xla_mode():
+        return False
+
     if not is_nvidia():
         return False
 
@@ -1101,7 +1142,7 @@ def resolve_lowvram_weight(weight, model, key): #TODO: remove
     print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
     return weight
 
-#TODO: might be cleaner to put this somewhere else
+# TODO: might be cleaner to put this somewhere else
 import threading
 
 class InterruptProcessingException(Exception):
diff --git a/latent_preview.py b/latent_preview.py
index d60e68d5..9dbfe11c 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -102,6 +102,9 @@ def prepare_callback(model, steps, x0_output_dict=None):
         preview_bytes = None
         if previewer:
             preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+        if args.xla:
+            import torch_xla as xla
+            xla.sync()
         pbar.update_absolute(step + 1, total_steps, preview_bytes)
     return callback
 
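
Note for reviewers (not part of the patch): the sketch below is a minimal, standalone exercise of the three `torch_xla` entry points the patch relies on (`torch_xla.device()`, `torch_xla.sync()`, `xm.get_memory_info()`), assuming a TPU/XLA host with `torch` and `torch_xla` installed per the README hunk above. The tensor sizes and prints are illustrative only.

```python
# Minimal sketch, not part of the patch: the torch_xla calls used above.
# Assumes a TPU/XLA host set up per the README instructions in this patch.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

dev = torch_xla.device()  # what get_torch_device() returns when --xla is set

# XLA tensors are lazy: ops queue into a graph that only executes on a sync,
# which is why the latent_preview.py hunk calls xla.sync() before handing the
# preview image to the progress bar.
x = torch.randn(1024, 1024, device=dev)
y = x @ x                 # queued lazily, not yet executed
torch_xla.sync()          # materialize the pending work on the device
print(y[0, 0].item())     # .item() copies the result back to the host

# xm.get_memory_info() reports bytes_used and bytes_limit, which
# get_xla_memory_info() maps onto ComfyUI's (reserved, total) convention.
info = xm.get_memory_info(dev)
print(dev, info["bytes_used"], info["bytes_limit"])
```

With the dependencies installed, ComfyUI is then launched on the accelerator with `python main.py --xla`, the flag added in the comfy/cli_args.py hunk.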