Remove duplicate code.
This commit is contained in:
parent
0a4c49c57c
commit
10b43ceea5
|
@ -7,6 +7,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from .. import attention
|
||||
from einops import rearrange, repeat
|
||||
from .util import timestep_embedding
|
||||
|
||||
def default(x, y):
|
||||
if x is not None:
|
||||
|
@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module):
|
|||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||
/ half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
|
Loading…
Reference in New Issue