src.model.mean_pooling_lstm module
- class src.model.mean_pooling_lstm.MeanPoolingLSTM(*args: Any, **kwargs: Any)
Bases: Module
LSTM-based temporal feature extractor.
Processes sequence data and uses the mean of the LSTM outputs across all time steps as the final sequence representation. Commonly used to encode pedestrian or vehicle trajectories.
- __init__(input_dim, embed_dim, layer_num)
- Parameters:
input_dim (int) – LSTM input feature dimension.
embed_dim (int) – LSTM hidden size, also the output dimension.
layer_num (int) – Number of LSTM layers.
- forward(x)
- Parameters:
x (torch.Tensor) – Input sequence tensor. Shape: (batch_size, num_agents, seq_len, input_dim). Any number of leading dimensions is accepted, as long as the last two are (seq_len, input_dim).
- Returns:
- Pooled feature tensor.
Shape: (batch_size, num_agents, embed_dim). All leading dimensions are preserved, the seq_len dimension is averaged out, and the last dimension becomes embed_dim.
- Return type:
torch.Tensor
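
The following is a minimal sketch of the behavior described above, not the module's actual source. It assumes the class wraps torch.nn.LSTM with batch_first=True and flattens all leading dimensions before running the LSTM. The name MeanPoolingLSTMSketch is hypothetical and used only for illustration.

```python
import torch
from torch import nn


class MeanPoolingLSTMSketch(nn.Module):
    """Hypothetical stand-in for MeanPoolingLSTM; internals are assumed."""

    def __init__(self, input_dim: int, embed_dim: int, layer_num: int):
        super().__init__()
        # Assumption: a plain multi-layer LSTM with batch-first inputs.
        self.lstm = nn.LSTM(
            input_dim, embed_dim, num_layers=layer_num, batch_first=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the shape into leading dims and the trailing (seq_len, input_dim).
        *lead, seq_len, input_dim = x.shape
        # Collapse every leading dimension into a single batch dimension.
        flat = x.reshape(-1, seq_len, input_dim)
        outputs, _ = self.lstm(flat)      # (N, seq_len, embed_dim)
        pooled = outputs.mean(dim=1)      # average over time steps: (N, embed_dim)
        # Restore the original leading dimensions.
        return pooled.reshape(*lead, pooled.shape[-1])


# Usage: 4 scenes, 3 agents each, 8 time steps of 2-D positions.
encoder = MeanPoolingLSTMSketch(input_dim=2, embed_dim=16, layer_num=1)
x = torch.randn(4, 3, 8, 2)
features = encoder(x)
```

Here features has shape (4, 3, 16): the batch and agent dimensions pass through unchanged, and each trajectory is summarized by one embed_dim-sized vector.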