initial minimal xla support
This commit is contained in:
parent d9f90965c8
commit e5182dd427

README.md: 17 changed lines
@@ -200,6 +200,23 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
 ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
 
+#### TPU/XLA Devices
+
+Users with TPU/XLA devices can install the PyTorch XLA stable build with the following command:
+
+```pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html```
+
+This is the command to install the 2.6.0 nightly build, which might have some performance improvements:
+
+```
+pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
+```
+
+To get memory info for TPU devices, install the [tpu-info](https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info) package with the following command:
+
+```pip install tpu-info```
+
 # Running
 
 ```python main.py```
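Not part of the commit, but a quick way to confirm the install above worked is to run a tiny op on the XLA device. This is a minimal sketch; it assumes only the `torch_xla` package installed by the commands in the README and the same `xla.device()` call the commit itself uses:

```python
# Minimal install check (illustrative, not from the commit).
import torch
import torch_xla as xla          # same alias the commit uses

dev = xla.device()               # the device ComfyUI will use when launched with --xla
x = torch.randn(2, 2, device=dev)
print(dev, (x @ x).cpu())        # moving the result to CPU forces the lazy graph to execute
```

If this prints a device such as `xla:0` followed by a 2x2 tensor, the wheel and the TPU runtime should be wired up correctly.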
@@ -138,6 +138,8 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
 
+parser.add_argument("--xla", action="store_true", help="Use XLA devices for everything.")
+
 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
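With the flag registered, XLA execution is opted into from the command line in the same way the README launches ComfyUI:

```
python main.py --xla
```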
@@ -36,6 +36,7 @@ class CPUState(Enum):
     GPU = 0
     CPU = 1
     MPS = 2
+    XLA = 3
 
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
@@ -84,9 +85,28 @@ try:
 except:
     pass
 
+try:
+    if args.xla:
+        import torch_xla as xla
+        import torch_xla.core.xla_model as xm
+        cpu_state = CPUState.XLA
+except ImportError:
+    logging.error("XLA not available, please install pytorch-xla")
+    pass
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
+def get_xla_memory_info(dev):
+    # xm.get_memory_info(dev) only has bytes_limit and bytes_used
+    mem_info = xm.get_memory_info(dev)
+    mem_reserved = mem_info["bytes_used"]
+    mem_total = mem_info["bytes_limit"]
+    return (mem_reserved, mem_total)
+
 def is_intel_xpu():
     global cpu_state
     global xpu_available
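For reference, the helper above leans entirely on `xm.get_memory_info`, which (as the code comment notes) only exposes a usage figure and a limit; the commit maps them onto the same `(reserved, total)` tuple shape the CUDA and XPU paths return. A standalone sketch of that call, assuming nothing beyond the torch_xla APIs the commit already imports:

```python
import torch_xla as xla
import torch_xla.core.xla_model as xm

dev = xla.device()
info = xm.get_memory_info(dev)                       # only 'bytes_used' and 'bytes_limit'
used, limit = info["bytes_used"], info["bytes_limit"]
print(f"XLA device memory: {used / 2**30:.2f} GiB used of {limit / 2**30:.2f} GiB")
```

The `tpu-info` tool mentioned in the README should report comparable per-chip numbers from outside the process.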
@@ -105,6 +125,8 @@ def get_torch_device():
         return torch.device("mps")
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
+    if cpu_state == CPUState.XLA:
+        return xla.device()
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
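So under `--xla`, `get_torch_device()` hands back the device object from `xla.device()` instead of a CUDA, XPU, MPS, or CPU device. A small sketch of what that object looks like (illustrative; only the torch_xla call already used above is assumed):

```python
import torch_xla as xla

dev = xla.device()
print(dev)       # e.g. xla:0, an ordinary torch.device usable as device= in tensor factories
print(dev.type)  # 'xla'
```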
@@ -128,6 +150,8 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total_torch = mem_reserved
             mem_total = torch.xpu.get_device_properties(dev).total_memory
+        elif cpu_state == CPUState.XLA:
+            mem_total_torch, mem_total = get_xla_memory_info(dev)
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -241,7 +265,7 @@ if lowvram_available:
         vram_state = set_vram_to
 
-if cpu_state != CPUState.GPU:
+if cpu_state != CPUState.GPU and cpu_state != CPUState.XLA:
     vram_state = VRAMState.DISABLED
 
 if cpu_state == CPUState.MPS:
@@ -924,6 +948,10 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_torch = mem_reserved - mem_active
             mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
             mem_free_total = mem_free_xpu + mem_free_torch
+        elif cpu_state == CPUState.XLA:
+            mem_reserved, mem_total = get_xla_memory_info(dev)
+            mem_free_total = mem_total - mem_reserved
+            mem_free_torch = mem_free_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -937,6 +965,10 @@ def get_free_memory(dev=None, torch_free_too=False):
     else:
         return mem_free_total
 
+def xla_mode():
+    global cpu_state
+    return cpu_state == CPUState.XLA
+
 def cpu_mode():
     global cpu_state
     return cpu_state == CPUState.CPU
@ -985,6 +1017,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||||
|
|
||||||
if cpu_mode():
|
if cpu_mode():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if xla_mode():
|
||||||
|
return True
|
||||||
|
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return True
|
return True
|
||||||
|
@@ -1044,6 +1079,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if cpu_mode():
         return False
 
+    if xla_mode():
+        return True
+
     if is_intel_xpu():
         return True
@@ -1062,6 +1100,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     return False
 
 def supports_fp8_compute(device=None):
+    if xla_mode():
+        return False
+
     if not is_nvidia():
         return False
 
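Taken together, the dtype policy under XLA is: fp16 and bf16 stay allowed, fp8 compute is ruled out. bf16 is the natural reduced-precision dtype on TPUs, and the sketch below (not from the commit; it assumes only the torch_xla install from the README) shows it working end to end on the XLA device:

```python
import torch
import torch_xla as xla

dev = xla.device()
a = torch.randn(4, 4, device=dev, dtype=torch.bfloat16)
b = torch.randn(4, 4, device=dev, dtype=torch.bfloat16)
print((a @ b).cpu().dtype)   # torch.bfloat16, matching should_use_bf16() returning True under XLA
```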
@@ -1101,7 +1142,7 @@ def resolve_lowvram_weight(weight, model, key): #TODO: remove
     print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
     return weight
 
-#TODO: might be cleaner to put this somewhere else
+# TODO: might be cleaner to put this somewhere else
 import threading
 
 class InterruptProcessingException(Exception):
@@ -102,6 +102,9 @@ def prepare_callback(model, steps, x0_output_dict=None):
         preview_bytes = None
         if previewer:
             preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+        if args.xla:
+            import torch_xla as xla
+            xla.sync()
         pbar.update_absolute(step + 1, total_steps, preview_bytes)
     return callback
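For context on the `xla.sync()` call above: PyTorch/XLA records tensor operations lazily and only compiles and executes the accumulated graph at a sync point. Syncing once per sampler step keeps the traced graph bounded and means the progress bar and preview reflect work that has actually run. A minimal sketch of the pattern, assuming only `torch_xla.sync()` as used in the commit:

```python
import torch
import torch_xla as xla

dev = xla.device()
x = torch.zeros(4, device=dev)
for step in range(10):
    x = x + 1      # recorded lazily; nothing executes yet
    xla.sync()     # cut and run the graph here, as the callback now does once per step
print(x.cpu())     # tensor([10., 10., 10., 10.])
```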