Do RMSNorm in native type.
This commit is contained in:
parent
ca4b8f30e0
commit
ab130001a8
|
@ -63,10 +63,8 @@ class RMSNorm(torch.nn.Module):
|
||||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
x_dtype = x.dtype
|
|
||||||
x = x.float()
|
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(torch.nn.Module):
|
class QKNorm(torch.nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue