initial minimal xla support
parent d9f90965c8
commit e5182dd427

README.md (+17)
@@ -200,6 +200,23 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve

```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```

#### TPU/XLA Devices

Users with TPU/XLA devices can install the stable PyTorch XLA build with the following command:

```pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html```

The following commands install the nightly 2.6.0 build, which may include performance improvements:

```
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
```

To get memory info for TPU devices, install the [tpu-info](https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info) package with the following command:

```pip install tpu-info```
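
With PyTorch XLA installed, ComfyUI should then run on the XLA device using the `--xla` flag added in this commit:

```python main.py --xla```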

# Running

```python main.py```
@@ -138,6 +138,8 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user

parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')

parser.add_argument("--xla", action="store_true", help="Use XLA devices for everything.")

# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
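Since the flag is a plain `store_true` switch, `args.xla` is simply a boolean. A minimal standalone sketch of that behaviour (an illustrative parser, not ComfyUI's own):

```
import argparse

# Standalone illustration of the new flag: store_true defaults to False
# and flips to True when --xla is passed on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--xla", action="store_true", help="Use XLA devices for everything.")

assert parser.parse_args([]).xla is False
assert parser.parse_args(["--xla"]).xla is True
```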
@@ -36,6 +36,7 @@ class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2
    XLA = 3

# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
@@ -84,9 +85,28 @@ try:
except:
    pass

try:
    if args.xla:
        import torch_xla as xla
        import torch_xla.core.xla_model as xm
        cpu_state = CPUState.XLA
except ImportError:
    logging.error("XLA not available, please install torch_xla")

if args.cpu:
    cpu_state = CPUState.CPU


def get_xla_memory_info(dev):
    # xm.get_memory_info(dev) only reports bytes_limit and bytes_used
    mem_info = xm.get_memory_info(dev)
    mem_reserved = mem_info["bytes_used"]
    mem_total = mem_info["bytes_limit"]
    return (mem_reserved, mem_total)


def is_intel_xpu():
    global cpu_state
    global xpu_available
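The new helper is a thin wrapper around `xm.get_memory_info`. A minimal standalone sketch of the same call, assuming torch_xla is installed and an XLA device (e.g. a TPU) is available:

```
import torch_xla as xla
import torch_xla.core.xla_model as xm

dev = xla.device()
info = xm.get_memory_info(dev)  # dict with 'bytes_used' and 'bytes_limit'

used_gib = info["bytes_used"] / 1024**3
limit_gib = info["bytes_limit"] / 1024**3
print(f"{dev}: {used_gib:.2f} GiB used of {limit_gib:.2f} GiB")
```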
@@ -105,6 +125,8 @@ def get_torch_device():
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    if cpu_state == CPUState.XLA:
        return xla.device()
    else:
        if is_intel_xpu():
            return torch.device("xpu", torch.xpu.current_device())
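Under `--xla`, callers of `get_torch_device()` now receive the XLA device, so tensors can be placed on it just as they would on a CUDA device. A small sketch, assuming torch_xla is installed:

```
import torch
import torch_xla as xla

# What get_torch_device() returns when cpu_state == CPUState.XLA.
dev = xla.device()

x = torch.randn(2, 3, device=dev)   # allocated on the XLA device
print(x.device)                      # e.g. "xla:0"
```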
@@ -128,6 +150,8 @@ def get_total_memory(dev=None, torch_total_too=False):
            mem_reserved = stats['reserved_bytes.all.current']
            mem_total_torch = mem_reserved
            mem_total = torch.xpu.get_device_properties(dev).total_memory
        elif cpu_state == CPUState.XLA:
            mem_total_torch, mem_total = get_xla_memory_info(dev)
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
@@ -241,7 +265,7 @@ if lowvram_available:
        vram_state = set_vram_to


if cpu_state != CPUState.GPU:
if cpu_state != CPUState.GPU and cpu_state != CPUState.XLA:
    vram_state = VRAMState.DISABLED

if cpu_state == CPUState.MPS:
@@ -924,6 +948,10 @@ def get_free_memory(dev=None, torch_free_too=False):
            mem_free_torch = mem_reserved - mem_active
            mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
            mem_free_total = mem_free_xpu + mem_free_torch
        elif cpu_state == CPUState.XLA:
            mem_reserved, mem_total = get_xla_memory_info(dev)
            mem_free_total = mem_total - mem_reserved
            mem_free_torch = mem_free_total
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
@@ -937,6 +965,10 @@ def get_free_memory(dev=None, torch_free_too=False):
    else:
        return mem_free_total

def xla_mode():
    global cpu_state
    return cpu_state == CPUState.XLA

def cpu_mode():
    global cpu_state
    return cpu_state == CPUState.CPU
@@ -985,6 +1017,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma

    if cpu_mode():
        return False

    if xla_mode():
        return True

    if is_intel_xpu():
        return True
@@ -1044,6 +1079,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma

    if cpu_mode():
        return False

    if xla_mode():
        return True

    if is_intel_xpu():
        return True
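TPUs natively accelerate bfloat16, which is presumably why `xla_mode()` short-circuits `should_use_fp16()` and `should_use_bf16()` to True. A small sketch, assuming torch_xla is installed:

```
import torch
import torch_xla as xla

dev = xla.device()

# bf16 weights and activations run on the TPU's native matrix units.
w = torch.randn(1024, 1024, dtype=torch.bfloat16, device=dev)
x = torch.randn(1024, 1024, dtype=torch.bfloat16, device=dev)
y = x @ w
print(y.dtype, y.device)
```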
@@ -1062,6 +1100,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
    return False

def supports_fp8_compute(device=None):
    if xla_mode():
        return False

    if not is_nvidia():
        return False
@@ -1101,7 +1142,7 @@ def resolve_lowvram_weight(weight, model, key): #TODO: remove
    print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
    return weight

#TODO: might be cleaner to put this somewhere else
# TODO: might be cleaner to put this somewhere else
import threading

class InterruptProcessingException(Exception):
@@ -102,6 +102,9 @@ def prepare_callback(model, steps, x0_output_dict=None):
        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        if args.xla:
            import torch_xla as xla
            xla.sync()
        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    return callback
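XLA tensors are evaluated lazily: operations are recorded into a graph and only executed at a sync point, which is why the callback above calls `xla.sync()` once per sampling step. A minimal sketch of that behaviour, assuming torch_xla is installed:

```
import torch
import torch_xla as xla

dev = xla.device()
x = torch.randn(4, 4, device=dev)

for _ in range(3):
    x = x @ x        # recorded into the lazy graph, not executed yet
    xla.sync()       # flush pending work to the XLA device each step

print(x.sum().item())  # .item() transfers to host, forcing execution
```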