Add basic XPU device support

closed #387
This commit is contained in:
藍+85CD 2023-04-05 21:22:14 +08:00
parent 30f274bf48
commit 37713e3b0a
1 changed files with 22 additions and 2 deletions

View File

@ -5,6 +5,7 @@ LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5
XPU = 6
accelerate_enabled = False
vram_state = NORMAL_VRAM
@ -85,10 +86,17 @@ try:
except:
pass
try:
import intel_extension_for_pytorch
if torch.xpu.is_available():
vram_state = XPU
except:
pass
if forced_cpu:
vram_state = CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state])
current_loaded_model = None
@ -141,6 +149,9 @@ def load_model_gpu(model):
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_state == XPU:
real_model.to("xpu")
pass
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
model_accelerated = False
real_model.cuda()
@ -189,6 +200,8 @@ def unload_if_low_vram(model):
def get_torch_device():
if vram_state == MPS:
return torch.device("mps")
if vram_state == XPU:
return torch.device("xpu")
if vram_state == CPU:
return torch.device("cpu")
else:
@ -228,6 +241,9 @@ def get_free_memory(dev=None, torch_free_too=False):
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
elif hasattr(dev, 'type') and (dev.type == 'xpu'):
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_free_torch = mem_free_total
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@ -258,8 +274,12 @@ def mps_mode():
global vram_state
return vram_state == MPS
def xpu_mode():
global vram_state
return vram_state == XPU
def should_use_fp16():
if cpu_mode() or mps_mode():
if cpu_mode() or mps_mode() or xpu_mode():
return False #TODO ?
if torch.cuda.is_bf16_supported():