Clamp output when rounding weight to prevent Nan.

This commit is contained in:
comfyanonymous 2024-10-19 19:07:10 -04:00
parent 518c0dc2fe
commit 73e3a9e676
1 changed files with 2 additions and 0 deletions

View File

@ -41,6 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
inf = torch.finfo(dtype)
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
return sign