Fix bfloat16 potentially not being enabled on mps.
This commit is contained in:
parent
48eb1399c0
commit
a6decf1e62
|
@ -897,7 +897,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode():
|
||||
if mps_mode():
|
||||
return True
|
||||
|
||||
if cpu_mode():
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
|
|
Loading…
Reference in New Issue