Fix to get fp8 working on T5 base.
parent a5991a7aa6
commit c24f897352
@@ -236,4 +236,6 @@ class T5(torch.nn.Module):
 
     def forward(self, input_ids, *args, **kwargs):
         x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
+        if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+            x = torch.nan_to_num(x) #Fix for fp8 T5 base
         return self.encoder(x, *args, **kwargs)
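For context, a minimal standalone sketch (not part of the commit) of what the two added lines do: the dtype check singles out non-standard weight dtypes such as fp8, and torch.nan_to_num replaces any NaN/inf values in the embedding output with finite numbers before they reach the encoder. The torch.float8_e4m3fn dtype below is only an illustrative assumption and requires a PyTorch build that ships the float8 dtypes.

import torch

# Illustrative stand-in for self.dtype; any dtype outside the listed
# "standard" ones (e.g. fp8) triggers the NaN scrub.
weight_dtype = torch.float8_e4m3fn

if weight_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
    # With fp8 weights the embedding output may contain NaN/inf values;
    # nan_to_num maps NaN -> 0.0 and +/-inf -> the largest finite values,
    # so downstream encoder layers never see non-finite activations.
    x = torch.tensor([0.5, float("nan"), float("inf"), float("-inf")])
    print(torch.nan_to_num(x))
    # prints approximately: tensor([ 5.0000e-01, 0.0000e+00, 3.4028e+38, -3.4028e+38])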