src.model.nan_embedding module


class src.model.nan_embedding.NanEmbedding(input_dim, embed_dim, disable=False)

Bases: torch.nn.Module

Linear embedding layer with explicit NaN handling.

Rather than passing NaN inputs through the linear map, this layer sets the embedding at each NaN position to a dedicated learnable vector (nan_embed). This is useful for representing missing observations in trajectory data.
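The replacement described above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the module's actual source. The class name NanEmbeddingSketch and the attribute linear are hypothetical. A position is assumed to count as missing when any of its input_dim features is NaN. nan_embed is assumed to have shape (embed_dim,), and disable=True is assumed to fall back to a plain linear map.

```python
import torch
import torch.nn as nn


class NanEmbeddingSketch(nn.Module):
    """Hypothetical re-implementation of the NaN-aware linear embedding."""

    def __init__(self, input_dim: int, embed_dim: int, disable: bool = False) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, embed_dim)
        # Learnable vector used in place of the linear output at missing positions
        # (assumed shape: (embed_dim,)).
        self.nan_embed = nn.Parameter(torch.randn(embed_dim) * 0.02)
        self.disable = disable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.disable:
            # Assumption: with NaN handling disabled, x goes straight through
            # the linear map, so NaNs propagate to the output.
            return self.linear(x)
        # Assumption: a position is missing if any of its features is NaN.
        missing = torch.isnan(x).any(dim=-1, keepdim=True)  # (..., 1)
        # Zero out NaNs so the linear map produces finite values everywhere;
        # those values are discarded at missing positions below.
        out = self.linear(torch.nan_to_num(x, nan=0.0))  # (..., embed_dim)
        # Broadcast the mask and nan_embed to (..., embed_dim) and select.
        return torch.where(missing, self.nan_embed, out)


torch.manual_seed(0)
layer = NanEmbeddingSketch(input_dim=3, embed_dim=8)
x = torch.tensor([[1.0, 2.0, 3.0],
                  [float("nan"), 0.5, 0.1]])  # second row has a missing value
y = layer(x)  # shape (2, 8); y[1] equals layer.nan_embed
```

Zeroing the NaNs before the linear map matters during training. torch.where sends zero gradient to the unselected branch, but the weight gradient of a linear layer is a product with the input, and 0 * NaN is still NaN. Unmasked NaNs would therefore corrupt the weight gradients even though their outputs are discarded.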

__init__(input_dim, embed_dim, disable=False)
Parameters:
  • input_dim (int) – Input feature dimension.

  • embed_dim (int) – Output embedding dimension.

  • disable (bool, optional) – If True, disable the NaN embedding behavior. Defaults to False.

forward(x)
Parameters:

x (torch.Tensor) – Input tensor, possibly containing NaN values. Shape: (…, input_dim)

Returns:

Embedded tensor in which positions with NaN inputs are replaced by nan_embed.

Shape: (…, embed_dim)

Return type:

torch.Tensor
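The shape contract of forward can be checked with a functional version of the same idea. In this sketch, embed_with_nan is a hypothetical helper. The tensors weight, bias, and nan_embed stand in for the layer's parameters. A position is assumed to be missing when any of its features is NaN. The leading dimensions (here batch and time, as in trajectory data) pass through unchanged, and only the last dimension changes from input_dim to embed_dim.

```python
import torch
import torch.nn.functional as F


def embed_with_nan(x: torch.Tensor, weight: torch.Tensor,
                   bias: torch.Tensor, nan_embed: torch.Tensor) -> torch.Tensor:
    """Map (..., input_dim) to (..., embed_dim), replacing missing positions.

    Hypothetical helper: weight is (embed_dim, input_dim); bias and
    nan_embed are (embed_dim,).
    """
    missing = torch.isnan(x).any(dim=-1, keepdim=True)  # (..., 1)
    out = F.linear(torch.nan_to_num(x, nan=0.0), weight, bias)
    return torch.where(missing, nan_embed, out)


# 4 trajectories, 10 time steps each, 3 features per step.
x = torch.randn(4, 10, 3)
x[0, 5, 1] = float("nan")  # one missing observation
weight = torch.randn(16, 3)
bias = torch.zeros(16)
nan_embed = torch.randn(16)

y = embed_with_nan(x, weight, bias, nan_embed)
print(y.shape)  # torch.Size([4, 10, 16])
```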