From ab130001a8b966ed788f7436aa3b689d038e42a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 27 Aug 2024 02:41:56 -0400 Subject: [PATCH] Do RMSNorm in native type. --- comfy/ldm/flux/layers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 9820832b..20bd2850 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -63,10 +63,8 @@ class RMSNorm(torch.nn.Module): self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) def forward(self, x: Tensor): - x_dtype = x.dtype - x = x.float() rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device) + return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device) class QKNorm(torch.nn.Module):