Source code for src.model.nan_embedding
import torch
import torch.nn as nn
[docs]
class NanEmbedding(nn.Module):
    """Linear embedding layer with explicit handling of missing (NaN) inputs.

    A position whose feature vector is *entirely* NaN is treated as a missing
    observation: its output is a dedicated learnable vector (``nan_embed``)
    instead of the linear projection, or an all-zero vector when the layer is
    built with ``disable=True``. This is useful for missing observations in
    trajectory data.

    Positions that are only partially NaN are not treated as missing; their
    NaN entries are replaced by zero and the vector goes through the normal
    linear map.
    """

    def __init__(self, input_dim: int, embed_dim: int, disable: bool = False) -> None:
        """
        Args:
            input_dim (int): Input feature dimension.
            embed_dim (int): Output embedding dimension.
            disable (bool): Whether to disable the NaN embedding behavior.
                When True, missing positions are embedded as zeros instead
                of ``nan_embed``.
        """
        super().__init__()
        # Mapping for regular values.
        self.embed = nn.Linear(input_dim, embed_dim)
        # Learnable vector for NaN inputs (created regardless of `disable`).
        self.nan_embed = nn.Parameter(torch.randn(embed_dim))
        self.disable = disable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor, possibly containing NaN values.
                Shape: (..., input_dim)

        Returns:
            torch.Tensor: Embedded tensor with fully-NaN positions replaced
                by ``nan_embed`` (or by zeros when the layer is disabled).
                Shape: (..., embed_dim)
        """
        # A position counts as missing only when all of its features are NaN.
        nan_mask = torch.isnan(x).all(dim=-1)

        # Zero out NaNs so they do not propagate through the linear map.
        # NOTE: nan_to_num also clamps +/-inf to the largest finite values.
        x = torch.nan_to_num(x, nan=0.0)
        out = self.embed(x)  # (..., embed_dim)

        if self.disable:
            fill = out.new_zeros(())
        else:
            # Match the projection's dtype: under mixed precision the linear
            # output may not be float32 while the parameter still is.
            fill = self.nan_embed.to(dtype=out.dtype)

        # Out-of-place selection instead of `out[nan_mask, :] = ...`: the
        # masked in-place write requires identical dtypes and mutates the
        # linear layer's output tensor.
        return torch.where(nan_mask.unsqueeze(-1), fill, out)