Remove duplicate code.

This commit is contained in:
comfyanonymous 2024-07-24 01:12:59 -04:00
parent 0a4c49c57c
commit 10b43ceea5
1 changed files with 2 additions and 27 deletions

View File

@ -7,6 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from .. import attention from .. import attention
from einops import rearrange, repeat from einops import rearrange, repeat
from .util import timestep_embedding
def default(x, y): def default(x, y):
if x is not None: if x is not None:
@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module):
) )
self.frequency_embedding_size = frequency_embedding_size self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
/ half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
if torch.is_floating_point(t):
embedding = embedding.to(dtype=t.dtype)
return embedding
def forward(self, t, dtype, **kwargs): def forward(self, t, dtype, **kwargs):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq) t_emb = self.mlp(t_freq)
return t_emb return t_emb