Do RMSNorm in native type.

This commit is contained in:
comfyanonymous 2024-08-27 02:41:56 -04:00
parent ca4b8f30e0
commit ab130001a8
1 changed files with 1 additions and 3 deletions

View File

@ -63,10 +63,8 @@ class RMSNorm(torch.nn.Module):
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor): def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device) return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)
class QKNorm(torch.nn.Module): class QKNorm(torch.nn.Module):