Source code for src.model.mean_pooling_lstm
import torch.nn as nn
[docs]
class MeanPoolingLSTM(nn.Module):
    """
    LSTM-based temporal feature extractor.

    Runs an LSTM over a sequence and uses the mean of the LSTM outputs across
    all time steps as the final sequence representation. Commonly used to
    encode pedestrian or vehicle trajectories.
    """

    def __init__(self, input_dim: int, embed_dim: int, layer_num: int) -> None:
        """
        Args:
            input_dim (int): LSTM input feature dimension.
            embed_dim (int): LSTM hidden size, also the output dimension.
            layer_num (int): Number of stacked LSTM layers.
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=embed_dim,
            num_layers=layer_num,
            batch_first=True,
        )

    def forward(self, x):
        """
        Encode each sequence and mean-pool the LSTM outputs over time.

        Args:
            x (torch.Tensor): Input sequence tensor.
                Shape: (batch_size, num_agents, seq_len, input_dim),
                or any leading dimensions as long as the last two are
                `(seq_len, input_dim)`.

        Returns:
            torch.Tensor: Pooled feature tensor.
                Shape: (batch_size, num_agents, embed_dim), i.e. the leading
                dimensions of `x` are preserved, the `seq_len` dimension is
                removed and the feature dimension becomes `embed_dim`.
        """
        lead_dims = x.shape[:-2]
        # Collapse every leading dimension into a single batch axis so the
        # LSTM sees a plain (B, seq_len, input_dim) batch. `reshape` (rather
        # than `view`) also accepts non-contiguous inputs.
        flat = x.reshape(-1, *x.shape[-2:])  # (B, seq_len, input_dim)
        out, _ = self.lstm(flat)  # (B, seq_len, embed_dim)
        pooled = out.mean(dim=-2)  # (B, embed_dim)
        # Restore the leading dimensions. The feature size must come from the
        # LSTM output (embed_dim), not from the input (input_dim); the two
        # differ in general.
        return pooled.reshape(*lead_dims, pooled.shape[-1])