diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 92745153..f37f7ff7 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn from .. import attention from einops import rearrange, repeat +from .util import timestep_embedding def default(x, y): if x is not None: @@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module): ) self.frequency_embedding_size = frequency_embedding_size - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. - """ - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) - / half - ) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - if torch.is_floating_point(t): - embedding = embedding.to(dtype=t.dtype) - return embedding - def forward(self, t, dtype, **kwargs): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype) t_emb = self.mlp(t_freq) return t_emb