Clamp output when rounding weight to prevent Nan.
This commit is contained in:
parent
518c0dc2fe
commit
73e3a9e676
|
@ -41,6 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
|||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
|
||||
inf = torch.finfo(dtype)
|
||||
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
|
||||
return sign
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue