diff --git a/README.md b/README.md
index 8e5cef45..f2e54f14 100644
--- a/README.md
+++ b/README.md
@@ -173,6 +173,16 @@ After this you should have everything installed and can proceed to running Comfy
 
 ### Others:
 
+#### Ascend NPUs
+
+We offer Ascend NPU support for all models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
+
+1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
+2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
+3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
+4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
+
+
 #### Intel GPUs
 
 Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 855e8911..580899b3 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -84,6 +84,13 @@ try:
 except:
     pass
 
+try:
+    import torch_npu
+    _ = torch.npu.device_count()
+    npu_available = torch.npu.is_available()
+except:
+    npu_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -95,6 +102,12 @@ def is_intel_xpu():
             return True
     return False
 
+def is_ascend_npu():
+    global npu_available
+    if npu_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -108,6 +121,8 @@ def get_torch_device():
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
+        elif is_ascend_npu():
+            return torch.device("npu", torch.npu.current_device())
         else:
             return torch.device(torch.cuda.current_device())
 
@@ -128,6 +143,12 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total_torch = mem_reserved
             mem_total = torch.xpu.get_device_properties(dev).total_memory
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_npu = torch.npu.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_npu
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -200,13 +221,13 @@ try:
                 ENABLE_PYTORCH_ATTENTION = True
         if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
             VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
-    if is_intel_xpu():
+    if is_intel_xpu() or is_ascend_npu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
     pass
 
-if is_intel_xpu():
+if is_intel_xpu() or is_ascend_npu():
     VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
 
 if args.cpu_vae:
@@ -266,6 +287,8 @@ def get_torch_device_name(device):
             return "{}".format(device.type)
     elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
+    elif is_ascend_npu():
+        return "{} {}".format(device, torch.npu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
@@ -866,6 +889,8 @@ def xformers_enabled():
         return False
     if is_intel_xpu():
         return False
+    if is_ascend_npu():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -890,6 +915,8 @@ def pytorch_attention_flash_attention():
             return True
         if is_intel_xpu():
             return True
+        if is_ascend_npu():
+            return True
     return False
 
 def force_upcast_attention_dtype():
@@ -924,6 +951,13 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_torch = mem_reserved - mem_active
             mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
             mem_free_total = mem_free_xpu + mem_free_torch
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_npu, _ = torch.npu.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_npu + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -988,6 +1022,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
     if is_intel_xpu():
         return True
+
+    if is_ascend_npu():
+        return True
 
     if torch.version.hip:
         return True
@@ -1088,6 +1125,8 @@ def soft_empty_cache(force=False):
         torch.mps.empty_cache()
     elif is_intel_xpu():
         torch.xpu.empty_cache()
+    elif is_ascend_npu():
+        torch.npu.empty_cache()
     elif torch.cuda.is_available():
         if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
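
For reference, below is a minimal standalone sketch (not part of the patch) of the detection pattern the diff adds to `comfy/model_management.py`: guard the `torch_npu` import, expose an `is_ascend_npu()` check, and query device memory with `torch.npu.mem_get_info`. The `__main__` probe and the GB formatting are illustrative additions only; in ComfyUI these calls are integrated into `get_torch_device()`, `get_total_memory()`, and `get_free_memory()` as shown in the hunks above.

```python
# Standalone sketch of the Ascend NPU detection pattern used by this patch.
# Assumes torch is installed; torch_npu is optional and only importable on
# machines with the Ascend toolkit (CANN + torch_npu) set up.
import torch

try:
    import torch_npu  # noqa: F401  (importing registers the "npu" backend)
    _ = torch.npu.device_count()
    npu_available = torch.npu.is_available()
except Exception:
    npu_available = False

def is_ascend_npu():
    return npu_available

if __name__ == "__main__":
    if is_ascend_npu():
        dev = torch.device("npu", torch.npu.current_device())
        mem_free, mem_total = torch.npu.mem_get_info(dev)
        print("Using {}: {:.1f} / {:.1f} GB free".format(
            dev, mem_free / (1024 ** 3), mem_total / (1024 ** 3)))
    else:
        print("No Ascend NPU detected; ComfyUI falls back to CUDA/XPU/CPU selection.")
```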