Mac supports bf16, just make sure you are using the latest pytorch.
parent e2382b6adb
commit 7ad574bffd
@@ -649,12 +649,12 @@ def supports_cast(device, dtype): #TODO
         return True
     if dtype == torch.float16:
         return True
-    if is_device_mps(device):
-        return False
     if directml_enabled: #TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
+    if is_device_mps(device):
+        return False
     if dtype == torch.float8_e4m3fn:
         return True
     if dtype == torch.float8_e5m2:
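Moving the is_device_mps early-return is what changes behaviour here: it previously ran before the bfloat16 branch, so casting to bf16 on MPS was reported as unsupported, whereas it now runs after it, so bf16 reports True on MPS while the fp8 dtypes still report False. Below is a minimal standalone sketch of the post-change ordering, assuming a PyTorch new enough to expose the float8 dtypes; is_device_mps and directml_enabled are stand-ins for the module's own helpers, not the real definitions.

import torch

directml_enabled = False  # stand-in for the module-level DirectML flag

def is_device_mps(device):
    # stand-in: the real helper lives in the same module as supports_cast
    return torch.device(device).type == "mps"

def supports_cast(device, dtype):
    if dtype == torch.float32:
        return True
    if dtype == torch.float16:
        return True
    if directml_enabled:
        return False
    if dtype == torch.bfloat16:
        return True   # reached before the mps early-return after this commit
    if is_device_mps(device):
        return False  # mps still reports the fp8 dtypes as unsupported
    if dtype == torch.float8_e4m3fn:
        return True
    if dtype == torch.float8_e5m2:
        return True
    return False

print(supports_cast("mps", torch.bfloat16))       # True with the new ordering
print(supports_cast("mps", torch.float8_e4m3fn))  # still False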
@@ -876,9 +876,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
         if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
             return False

-    if device is not None: #TODO not sure about mps bf16 support
+    if device is not None:
         if is_device_mps(device):
-            return False
+            return True

     if FORCE_FP32:
         return False
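The second hunk flips the MPS branch of should_use_bf16 from False to True, matching the commit message. A quick sanity check that bf16 actually works on a given Mac is sketched below, assuming a recent PyTorch wheel with the MPS backend built in; if the allocation or matmul fails, updating PyTorch is the likely fix.

import torch

if torch.backends.mps.is_available():
    try:
        x = torch.randn(4, 4, dtype=torch.bfloat16, device="mps")
        y = (x @ x).float()
        print("bf16 matmul on mps ok:", tuple(y.shape))
    except (RuntimeError, TypeError) as err:
        print("bf16 on mps failed, try updating PyTorch:", err)
else:
    print("MPS backend not available on this machine")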