Source code for src.model.sinusoidal_embedding
import math
import torch
import torch.nn as nn
[docs]
class SinusoidalEmbedding(nn.Module):
    """
    Sinusoidal embedding generator.

    Commonly used in diffusion models to map discrete timesteps into
    high-dimensional continuous feature vectors. Includes a sinusoidal encoding
    stage (``[sin(t * f) | cos(t * f)]`` over a geometric ladder of
    frequencies ``f``) followed by an MLP projection
    (Linear -> ReLU -> Linear, hidden width ``4 * embed_dim``).
    """

    def __init__(self, embed_dim, max_period=10000.0):
        """
        Args:
            embed_dim (int): Output embedding dimension. Must be >= 2. Odd
                values are supported: the sinusoidal features are zero-padded
                by one column so the MLP still receives ``embed_dim`` inputs.
            max_period (float): Sets the lowest frequency of the encoding.
                For ``embed_dim >= 4`` the frequencies run geometrically from
                1 down to ``1 / max_period``. Defaults to 10000.0.

        Raises:
            ValueError: If ``embed_dim`` is smaller than 2.
        """
        super().__init__()
        if embed_dim < 2:
            raise ValueError(f"embed_dim must be >= 2, got {embed_dim}")
        self.embed_dim = embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )
        self.half_dim = self.embed_dim // 2
        # Guard the denominator: with half_dim == 1 (embed_dim 2 or 3) the
        # original formula divided by zero. Any positive divisor yields the
        # single frequency exp(0) == 1 in that case, and the value is
        # unchanged for half_dim >= 2.
        scale = math.log(max_period) / max(self.half_dim - 1, 1)
        freq = torch.exp(-torch.arange(self.half_dim).float() * scale)[None, :]
        # A non-persistent buffer follows module.to()/.cuda() together with the
        # parameters but is not written to state_dict (the former plain
        # attribute was not saved either), so existing checkpoints still load.
        self.register_buffer("freq", freq, persistent=False)

    def forward(self, t: torch.Tensor):
        """
        Args:
            t (torch.Tensor): Timestep indices, typically a LongTensor of
                shape (batch_size, ). Float timesteps and other shapes,
                including a 0-d scalar, are also accepted.

        Returns:
            torch.Tensor: Timestep embeddings of shape (*t.shape, embed_dim),
            i.e. (batch_size, embed_dim) for 1-D input.
        """
        return self.mlp(self._sinusoidal(t))

    def _sinusoidal(self, t):
        """Return the [sin | cos] features of ``t``, zero-padded to embed_dim."""
        # Phases are computed in float32 regardless of the module dtype so a
        # .half() module does not lose precision in t * freq.
        freq = self.freq.to(device=t.device, dtype=torch.float32).reshape(-1)
        args = t.float().unsqueeze(-1) * freq  # (*t.shape, half_dim)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.embed_dim % 2 == 1:
            # sin + cos give 2 * half_dim == embed_dim - 1 features; append one
            # zero column so the first Linear layer gets embed_dim inputs.
            emb = nn.functional.pad(emb, (0, 1))
        # Feed the MLP in its own parameter dtype (e.g. after .double()/.half()).
        return emb.to(self.mlp[0].weight.dtype)