Remove duplicate code.
This commit is contained in:
parent
0a4c49c57c
commit
10b43ceea5
|
@ -7,6 +7,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from .. import attention
|
from .. import attention
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
from .util import timestep_embedding
|
||||||
|
|
||||||
def default(x, y):
|
def default(x, y):
|
||||||
if x is not None:
|
if x is not None:
|
||||||
|
@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module):
|
||||||
)
|
)
|
||||||
self.frequency_embedding_size = frequency_embedding_size
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def timestep_embedding(t, dim, max_period=10000):
|
|
||||||
"""
|
|
||||||
Create sinusoidal timestep embeddings.
|
|
||||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
|
||||||
These may be fractional.
|
|
||||||
:param dim: the dimension of the output.
|
|
||||||
:param max_period: controls the minimum frequency of the embeddings.
|
|
||||||
:return: an (N, D) Tensor of positional embeddings.
|
|
||||||
"""
|
|
||||||
half = dim // 2
|
|
||||||
freqs = torch.exp(
|
|
||||||
-math.log(max_period)
|
|
||||||
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
|
||||||
/ half
|
|
||||||
)
|
|
||||||
args = t[:, None].float() * freqs[None]
|
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
|
||||||
if dim % 2:
|
|
||||||
embedding = torch.cat(
|
|
||||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
|
||||||
)
|
|
||||||
if torch.is_floating_point(t):
|
|
||||||
embedding = embedding.to(dtype=t.dtype)
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
def forward(self, t, dtype, **kwargs):
|
def forward(self, t, dtype, **kwargs):
|
||||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||||
t_emb = self.mlp(t_freq)
|
t_emb = self.mlp(t_freq)
|
||||||
return t_emb
|
return t_emb
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue