diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 9820832b..20bd2850 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -63,10 +63,8 @@ class RMSNorm(torch.nn.Module): self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) def forward(self, x: Tensor): - x_dtype = x.dtype - x = x.float() rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device) + return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device) class QKNorm(torch.nn.Module):