parent
30f274bf48
commit
37713e3b0a
|
@ -5,6 +5,7 @@ LOW_VRAM = 2
|
|||
NORMAL_VRAM = 3
|
||||
HIGH_VRAM = 4
|
||||
MPS = 5
|
||||
XPU = 6
|
||||
|
||||
accelerate_enabled = False
|
||||
vram_state = NORMAL_VRAM
|
||||
|
@ -85,10 +86,17 @@ try:
|
|||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch
|
||||
if torch.xpu.is_available():
|
||||
vram_state = XPU
|
||||
except:
|
||||
pass
|
||||
|
||||
if forced_cpu:
|
||||
vram_state = CPU
|
||||
|
||||
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
|
||||
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state])
|
||||
|
||||
|
||||
current_loaded_model = None
|
||||
|
@ -141,6 +149,9 @@ def load_model_gpu(model):
|
|||
mps_device = torch.device("mps")
|
||||
real_model.to(mps_device)
|
||||
pass
|
||||
elif vram_state == XPU:
|
||||
real_model.to("xpu")
|
||||
pass
|
||||
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
|
||||
model_accelerated = False
|
||||
real_model.cuda()
|
||||
|
@ -189,6 +200,8 @@ def unload_if_low_vram(model):
|
|||
def get_torch_device():
|
||||
if vram_state == MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == XPU:
|
||||
return torch.device("xpu")
|
||||
if vram_state == CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
|
@ -228,6 +241,9 @@ def get_free_memory(dev=None, torch_free_too=False):
|
|||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = mem_free_total
|
||||
elif hasattr(dev, 'type') and (dev.type == 'xpu'):
|
||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
|
@ -258,8 +274,12 @@ def mps_mode():
|
|||
global vram_state
|
||||
return vram_state == MPS
|
||||
|
||||
def xpu_mode():
|
||||
global vram_state
|
||||
return vram_state == XPU
|
||||
|
||||
def should_use_fp16():
|
||||
if cpu_mode() or mps_mode():
|
||||
if cpu_mode() or mps_mode() or xpu_mode():
|
||||
return False #TODO ?
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
|
|
Loading…
Reference in New Issue