Fix to get fp8 working on T5 base.

comfyanonymous 2024-07-31 02:00:19 -04:00
parent a5991a7aa6
commit c24f897352
1 changed file with 2 additions and 0 deletions


@@ -236,4 +236,6 @@ class T5(torch.nn.Module):
    def forward(self, input_ids, *args, **kwargs):
        x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
+       if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+           x = torch.nan_to_num(x) #Fix for fp8 T5 base
        return self.encoder(x, *args, **kwargs)
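
For context, here is a minimal standalone sketch of what the added guard does; it is not the ComfyUI code path, and the hand-built tensor and the assumed fp8 weight dtype (torch.float8_e4m3fn, available in recent PyTorch builds) are illustrative only. The idea is that when weights are stored in a float8 dtype, the embedding output can contain NaN or inf values, and torch.nan_to_num replaces them with finite numbers before the tensor reaches the encoder.

import torch

# Illustrative tensor standing in for the embedding output "x"; built by hand
# to contain the NaN/inf values that can show up with fp8 weights.
x = torch.tensor([1.0, float("nan"), float("inf"), -float("inf")])

# Assumed fp8 storage dtype (requires a PyTorch build that exposes float8 dtypes).
weight_dtype = torch.float8_e4m3fn

# Same shape of guard as the commit: only sanitize when the weight dtype is
# not one of the standard full/half-precision floats.
if weight_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
    x = torch.nan_to_num(x)  # NaN -> 0.0, +/-inf -> largest finite values

print(x)  # tensor([ 1.0000e+00,  0.0000e+00,  3.4028e+38, -3.4028e+38])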