Source code for src.model.fourier_positional_encoding
import math
import torch
import torch.nn as nn
[docs]
class FourierPositionalEncoding(nn.Module):
    """
    Fourier positional encoding (Fourier feature mapping).

    Maps low-dimensional continuous coordinates such as `(x, y)` into a
    high-dimensional frequency space with sine and cosine functions, then
    projects the result to `out_dim` with a learned linear layer. This helps
    neural networks learn high-frequency details more effectively, as in NeRF.

    Attributes:
        freqs (torch.Tensor): Non-persistent buffer of shape `(num_bands,)`
            holding the base frequencies, linearly spaced from `min_freq`
            to `0.5`. They are multiplied by `2 * pi` in `forward`.
        proj (nn.Linear): Learned projection from the `num_bands * 4`
            sine/cosine features to `out_dim`.
    """

    def __init__(self, out_dim: int = 256, num_bands: int = 64, min_freq: float = 1e-3):
        """
        Args:
            out_dim (int): Final output embedding dimension.
            num_bands (int): Number of frequency bands used.
            min_freq (float): Minimum base frequency.
        """
        super().__init__()
        # Register the frequency table as a buffer so it follows the module
        # through `.to()`, `.cuda()`, `.double()`, etc. instead of being
        # transferred to the input's device on every forward pass.
        # `persistent=False` keeps it out of `state_dict()`, so checkpoints
        # saved when `freqs` was a plain attribute still load unchanged.
        self.register_buffer(
            "freqs",
            torch.linspace(min_freq, 0.5, num_bands),
            persistent=False,
        )
        # Four features per band: sin/cos of the x and of the y coordinate.
        self.proj = nn.Linear(num_bands * 4, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input coordinates.
                Shape: `(..., 2)`, assuming the last dimension is `(x, y)`.
                Only indices 0 and 1 of the last dimension are read.

        Returns:
            torch.Tensor: Positional encoding features.
                Shape: `(..., out_dim)`.
        """
        # Convert base frequencies to angular frequencies.
        angular = self.freqs * math.pi * 2

        # Slices keep the trailing axis so each coordinate broadcasts
        # against the frequency table.
        x_proj = x[..., 0:1] * angular  # (..., num_bands)
        y_proj = x[..., 1:2] * angular  # (..., num_bands)

        # The feature order must stay [sin x, cos x, sin y, cos y]; the
        # weights of `proj` are tied to this layout.
        fourier = torch.cat(
            [
                torch.sin(x_proj),
                torch.cos(x_proj),
                torch.sin(y_proj),
                torch.cos(y_proj),
            ],
            dim=-1,
        )  # (..., num_bands*4)

        # sin/cos of a non-finite coordinate is NaN; zero those features so
        # they do not propagate through the projection.
        fourier = torch.nan_to_num(fourier, nan=0.0)
        return self.proj(fourier)