Small optimization.

This commit is contained in:
comfyanonymous 2024-06-15 02:44:38 -04:00
parent f2e844e054
commit 1281f933c1
1 changed files with 2 additions and 2 deletions

View File

@ -243,9 +243,9 @@ class TimestepEmbedder(nn.Module):
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
/ half
).to(device=t.device)
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: