# Source code for src.model.model

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.utils.timer import NamedTimer
from .sinusoidal_embedding import SinusoidalEmbedding
from .fourier_positional_encoding import FourierPositionalEncoding
from .residual import Residual
from .nan_embedding import NanEmbedding
from .permuted import Permuted
from .mean_pooling_lstm import MeanPoolingLSTM


class Model(nn.Module):
    """
    Pedestrian trajectory prediction model based on Transformer and diffusion.

    The model combines pedestrian history, social interaction with nearby
    pedestrians, interaction with surrounding vehicles, and static map context
    to predict motion intent, either acceleration or noise, during reverse
    denoising in DDPM/DDIM.

    The module is stateful. Scene context is cached on the instance by
    ``set_ped_embedding``, ``set_veh_embedding``, ``set_map_embedding`` and
    then ``set_sur_info``. ``set_sur_info`` reads the positions cached by
    ``set_ped_embedding`` and the map cached by ``set_map_embedding``.
    ``forward`` only reads that cached state. It can therefore be called
    repeatedly with different ``denoise_t`` / ``noisy_acc`` without
    recomputing the scene context.
    """

    def __init__(self, args):
        """
        Initialize model layers and embedding modules.

        Args:
            args (Namespace): Configuration object. Fields read here:
                - model_dim (int): Internal model feature dimension. Must be
                  divisible by ``head_num``, as required by PyTorch attention.
                - map_feature_dim (int): Intermediate map CNN channel count.
                  ``map_feature_dim // 4`` and ``map_feature_dim // 2`` are
                  used as the earlier channel counts.
                - lstm_layer_num (int): Number of LSTM layers for temporal data.
                - head_num (int): Number of attention heads.
                - attention_layer_num (int): Number of Transformer decoder layers.
                - latent_token_num (int): Number of latent tokens used to
                  compress map features.
                - dropout (float): Dropout ratio inside the decoder layers.
                - pred_step (int): Prediction horizon.
                - use_spatial_anchor (bool): Whether to enhance latent tokens
                  with positional encodings of a regular spatial anchor grid.
                - use_nan_embedding (bool): When False, every ``NanEmbedding``
                  is constructed with ``disable=True``.
                Fields read later by other methods:
                - use_latent_query (bool): Read by ``set_map_embedding``.
                - device: Read by ``forward``, only to synchronize CUDA when
                  profiling with a timer.

        Raises:
            ValueError: If ``use_spatial_anchor`` is set and
                ``latent_token_num`` is not a perfect square.
        """
        super().__init__()
        self.args = args
        dim = args.model_dim
        map_dim = args.map_feature_dim
        disable_nan = not args.use_nan_embedding

        # Diffusion timestep and noisy-acceleration embeddings.
        self.denoise_t_embedder = SinusoidalEmbedding(dim)
        self.noisy_acc_embedder = nn.Sequential(
            nn.Linear(2, dim),
            MeanPoolingLSTM(dim, dim, args.lstm_layer_num),
            nn.LayerNorm(dim),
        )
        self.positional_encoding = FourierPositionalEncoding(out_dim=dim, num_bands=dim, min_freq=1e-3)

        # Per-pedestrian state embeddings; they are summed in set_ped_embedding.
        self.pos_embedder = nn.Sequential(
            nn.Linear(2, dim),
            nn.LayerNorm(dim),
        )
        self.vel_embedder = nn.Sequential(
            nn.Linear(2, dim),
            nn.LayerNorm(dim),
        )
        self.hst_embedder = nn.Sequential(
            NanEmbedding(2, dim, disable=disable_nan),
            MeanPoolingLSTM(dim, dim, args.lstm_layer_num),
            nn.LayerNorm(dim),
        )
        self.des_embedder = nn.Sequential(
            NanEmbedding(2, dim, disable=disable_nan),
            nn.LayerNorm(dim),
        )
        self.spd_embedder = nn.Sequential(
            NanEmbedding(1, dim, disable=disable_nan),
            nn.LayerNorm(dim),
        )
        self.ped_encoder = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim),
        )

        # Vehicle history embedding.
        self.veh_embedder = nn.Sequential(
            NanEmbedding(2, dim, disable=disable_nan),
            MeanPoolingLSTM(dim, dim, args.lstm_layer_num),
            nn.LayerNorm(dim),
        )

        # Rasterized map -> dense per-cell features (unbatched Conv2d input).
        self.map_embedder = nn.Sequential(
            NanEmbedding(1, map_dim // 4, disable=disable_nan),
            Permuted(2, 0, 1),  # channels-last -> channels-first
            nn.Conv2d(map_dim // 4, map_dim // 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(map_dim // 2, map_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            # 1x1 convolution, equivalent to a per-location Linear(C -> model_dim).
            nn.Conv2d(map_dim, dim, kernel_size=1),
            Permuted(1, 2, 0),  # channels-first -> channels-last
            nn.LayerNorm(dim),
        )

        # Cross-attention stacks. The creation order is kept stable so that
        # seeded parameter initialization is reproducible.
        self.ped_attention = self._build_decoder(args)
        self.veh_attention = self._build_decoder(args)
        self.map_attention = self._build_decoder(args)
        self.latent_attntn = self._build_decoder(args)
        self.latent_tokens = nn.Parameter(
            torch.randn(args.latent_token_num, dim)
        )

        if args.use_spatial_anchor:
            # The latent tokens are laid out on an S x S grid, so the count
            # must be a perfect square (e.g. 16 or 64).
            grid_size = math.isqrt(args.latent_token_num)
            if grid_size ** 2 != args.latent_token_num:
                raise ValueError(f"latent_token_num ({args.latent_token_num}) must be a square number when use_spatial_anchor is True.")
            self.grid_size = grid_size
            # Normalized [0, 1] grid, mapped to world coordinates in
            # set_map_embedding. A buffer is saved in the state dict but is
            # not a trainable parameter.
            axis = torch.linspace(0, 1, grid_size)
            grid_x, grid_y = torch.meshgrid(axis, axis, indexing='ij')
            self.register_buffer('anchor_norm', torch.stack([grid_x, grid_y], dim=-1))  # (S, S, 2)

        self.fusion_fc = Residual(
            nn.LayerNorm(dim),
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim),
        )
        self.output_fc = Residual(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim // 2),
            nn.ReLU(),
            nn.Linear(dim // 2, args.pred_step * 2),
            input_dim=dim,
            output_dim=args.pred_step * 2,
        )

    @staticmethod
    def _build_decoder(args) -> nn.TransformerDecoder:
        """Build a pre-norm, batch-first Transformer decoder stack configured from ``args``."""
        layer = nn.TransformerDecoderLayer(
            d_model=args.model_dim,
            nhead=args.head_num,
            dim_feedforward=4 * args.model_dim,
            dropout=args.dropout,
            activation='relu',
            batch_first=True,
            norm_first=True,
        )
        return nn.TransformerDecoder(layer, num_layers=args.attention_layer_num)

    @staticmethod
    def _build_padding_mask(lengths: torch.LongTensor, batch_size: int, max_num: int) -> torch.BoolTensor:
        """Return a ``(batch_size, max_num)`` mask that is True at padded (invalid) slots."""
        slots = torch.arange(max_num, device=lengths.device).unsqueeze(0).expand(batch_size, max_num)
        return slots >= lengths.unsqueeze(1)

    def _record_time(self, timer, name: str) -> None:
        """Record a profiling checkpoint ``name`` on ``timer``; do nothing when ``timer`` is None."""
        if timer is None:
            return
        device = torch.device(self.args.device)
        # Only CUDA work is asynchronous. torch.cuda.synchronize raises for
        # non-CUDA devices, so CPU profiling would otherwise crash.
        if device.type == 'cuda':
            torch.cuda.synchronize(device=device)
        timer.add(name)

    def set_ped_embedding(
        self,
        pos: torch.FloatTensor,
        vel: torch.FloatTensor,
        hst: torch.FloatTensor,
        des: torch.FloatTensor,
        spd: torch.FloatTensor,
    ):
        """
        Compute and cache the joint pedestrian embedding.

        Position, velocity, history, destination and desired speed are each
        embedded to ``model_dim`` and summed. A Fourier positional encoding of
        the current position is computed as well.

        Args:
            pos (torch.FloatTensor): Current pedestrian coordinates `(x, y)`.
                Shape: (batch_size, num_peds, 2)
            vel (torch.FloatTensor): Current pedestrian velocity `(vx, vy)`.
                Shape: (batch_size, num_peds, 2)
            hst (torch.FloatTensor): Pedestrian history trajectory sequence.
                Shape: (batch_size, num_peds, hist_step, 2)
            des (torch.FloatTensor): Pedestrian destination coordinates.
                Shape: (batch_size, num_peds, 2)
            spd (torch.FloatTensor): Pedestrian desired speed scalar.
                Shape: (batch_size, num_peds, 1)

        Side Effects:
            Sets `self.ped_embedding`: fused pedestrian features.
            Sets `self.pos`: cached current positions, used by `set_sur_info`.
            Sets `self.pe`: positional encoding of `pos`, used by `forward`.
        """
        pos_embedding = self.pos_embedder(pos)  # (batch_size, #pedestrian, model_dim)
        vel_embedding = self.vel_embedder(vel)  # (batch_size, #pedestrian, model_dim)
        hst_embedding = self.hst_embedder(hst)  # (batch_size, #pedestrian, model_dim)
        des_embedding = self.des_embedder(des)  # (batch_size, #pedestrian, model_dim)
        spd_embedding = self.spd_embedder(spd)  # (batch_size, #pedestrian, model_dim)
        self.ped_embedding = pos_embedding + vel_embedding + hst_embedding + des_embedding + spd_embedding
        self.pos = pos
        self.pe = self.positional_encoding(pos)  # (batch_size, #pedestrian, model_dim)

    def set_veh_embedding(
        self,
        veh: torch.FloatTensor,
    ):
        """
        Compute and cache vehicle feature embeddings.

        If the batch contains no vehicles (``veh.shape[1] == 0``), a single
        all-NaN placeholder vehicle is substituted. This keeps the attention
        memory non-empty; `forward` unmasks that slot for such samples.

        Args:
            veh (torch.FloatTensor): Vehicle history trajectory sequence.
                Shape: (batch_size, num_vehs, hist_step + 1, 2)

        Side Effects:
            Sets `self.veh_embedding`: vehicle feature vectors.
        """
        if veh.size(1) == 0:
            placeholder_shape = list(veh.shape)
            placeholder_shape[1] = 1
            veh = torch.full(placeholder_shape, float('nan'), device=veh.device)
        self.veh_embedding = self.veh_embedder(veh)  # (batch_size, #vehicle, model_dim)

    def set_map_embedding(
        self,
        map: torch.FloatTensor,  # shadows the builtin; name kept for keyword callers
        xmin: torch.FloatTensor,
        xmax: torch.FloatTensor,
        ymin: torch.FloatTensor,
        ymax: torch.FloatTensor,
    ):
        """
        Compute and cache static map embeddings.

        A CNN extracts per-cell features from the rasterized map, and an
        absolute positional encoding of each cell's world coordinate is added.
        When ``args.use_latent_query`` is set, learnable latent tokens compress
        the dense features by cross-attention. Otherwise the flattened dense
        features are used directly.

        Args:
            map (torch.FloatTensor): Rasterized environment map, where `0`
                denotes walkable area and `1` denotes obstacles.
                Shape: `(Map_W, Map_H)`, with the first dimension as `x`
                pointing right and the second as `y` pointing up.
            xmin, xmax, ymin, ymax: Map bounds in world coordinates. Each is a
                Python float or 0-dim tensor (they are passed as
                `torch.linspace` endpoints).

        Side Effects:
            Sets `self.map_embedding`: dense grid-map features, (Map_W, Map_H, model_dim).
            Sets `self.ltn_embedding`: compressed (or flattened) map features.
            Caches `self.map` and the bounds `self.xmin`, `self.xmax`,
            `self.ymin`, `self.ymax`, used by `set_sur_info`.
        """
        map_embedding = self.map_embedder(map.unsqueeze(-1))  # (W', H', model_dim)
        width, height = map_embedding.size(0), map_embedding.size(1)
        xs = torch.linspace(xmin, xmax, width, device=map_embedding.device)
        ys = torch.linspace(ymin, ymax, height, device=map_embedding.device)
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='ij')
        cell_xy = torch.stack([grid_x, grid_y], dim=-1)  # (W', H', 2)
        map_embedding = map_embedding + self.positional_encoding(cell_xy)  # (W', H', model_dim)

        latent_tokens = self.latent_tokens
        if self.args.use_spatial_anchor:
            # Place the anchor grid over the map's world-coordinate extent.
            anchor_phys_x = xmin + self.anchor_norm[..., 0] * (xmax - xmin)
            anchor_phys_y = ymin + self.anchor_norm[..., 1] * (ymax - ymin)
            anchor_phys = torch.stack([anchor_phys_x, anchor_phys_y], dim=-1)  # (S, S, 2)
            anchor_pe = self.positional_encoding(anchor_phys)  # (S, S, model_dim)
            latent_tokens = latent_tokens + anchor_pe.flatten(0, 1)  # (S, S, D) -> (K, D)

        if self.args.use_latent_query:
            ltn_embedding = self.latent_attntn(latent_tokens, map_embedding.flatten(0, 1))  # (#latent_token, model_dim)
        else:
            ltn_embedding = map_embedding.flatten(0, 1)

        self.map_embedding = map_embedding
        self.ltn_embedding = ltn_embedding
        self.map = map
        self.xmax = xmax
        self.xmin = xmin
        self.ymax = ymax
        self.ymin = ymin

    def set_sur_info(self):
        """
        Look up dense map features at each pedestrian's current position.

        Converts the cached world coordinates in `self.pos` to raster indices
        into `self.map_embedding` and gathers the feature vector there. The
        indices are clamped to the map, so off-map pedestrians take the
        nearest border cell.

        Side Effects:
            Sets `self.sur_info`: layer-normalized features,
            shape (batch_size, #pedestrian, model_dim).
        """
        map_embedding = self.map_embedding
        width, height = map_embedding.size(0), map_embedding.size(1)
        # NOTE(review): these indices scale by `size` and round, while the
        # positional-encoding grid in set_map_embedding uses
        # torch.linspace(min, max, size), i.e. spacing (max - min) / (size - 1).
        # The two conventions can differ by up to one cell. Left unchanged to
        # stay compatible with trained weights; confirm which convention the
        # rasterized map follows.
        idx = ((self.pos[..., 0].detach() - self.xmin) / (self.xmax - self.xmin) * width).round().long().clamp(0, width - 1)  # (batch_size, #pedestrian)
        jdx = ((self.pos[..., 1].detach() - self.ymin) / (self.ymax - self.ymin) * height).round().long().clamp(0, height - 1)  # (batch_size, #pedestrian)
        sur_info = map_embedding[idx, jdx]  # (batch_size, #pedestrian, model_dim)
        self.sur_info = F.layer_norm(sur_info, sur_info.shape[-1:])

    def forward(
        self,
        denoise_t: torch.LongTensor,
        noisy_acc: torch.FloatTensor,
        ped_length: torch.LongTensor,
        veh_length: torch.LongTensor,
        timer: "NamedTimer" = None,
    ):
        """
        Predict the denoising target for the cached scene.

        Requires `set_ped_embedding`, `set_veh_embedding`, `set_map_embedding`
        and `set_sur_info` to have been called first. This method reads
        `self.ped_embedding`, `self.pe`, `self.veh_embedding`,
        `self.ltn_embedding` and `self.sur_info`, and does not modify them.

        It fuses, through Transformer decoder stacks:
            1. Diffusion timestep embedding (`t`)
            2. Current noisy acceleration embedding (`x_t`)
            3. Social interaction (`Ped-Ped Attention`)
            4. Pedestrian-vehicle interaction (`Ped-Veh Attention`)
            5. Environment interaction (`Ped-Map Attention` and the
               per-pedestrian map lookup)

        Args:
            denoise_t (torch.LongTensor): Current diffusion timestep `t`.
                Shape: (batch_size,)
            noisy_acc (torch.FloatTensor): Noisy future acceleration sequence `x_t`.
                Shape: (batch_size, num_peds, pred_step, 2)
            ped_length (torch.LongTensor): Number of valid pedestrians per sample.
                Shape: (batch_size,)
            veh_length (torch.LongTensor): Number of valid vehicles per sample.
                Shape: (batch_size,)
            timer (NamedTimer, optional): If given, a checkpoint is recorded
                after each stage. CUDA is synchronized first when
                `args.device` is a CUDA device.

        Returns:
            torch.FloatTensor: Model prediction, reshaped to
            (batch_size, num_peds, pred_step, 2). Whether this is noise
            `epsilon` or `x_0` depends on how the caller trains and uses the
            model (e.g. `args.predict_noise`). That flag is not read here.

        Raises:
            ValueError: If any sample has zero valid pedestrians.
        """
        # Embedding Pedestrian
        denoise_t_embedding = self.denoise_t_embedder(denoise_t).unsqueeze(1)  # (batch_size, 1, model_dim)
        noisy_acc_embedding = self.noisy_acc_embedder(noisy_acc)  # (batch_size, #pedestrian, model_dim)
        ped_embedding = self.ped_encoder(self.ped_embedding + denoise_t_embedding + noisy_acc_embedding)  # (batch_size, #pedestrian, model_dim)
        self._record_time(timer, 'Embedding Pedestrian')

        veh_embedding = self.veh_embedding  # (batch_size, #vehicle, model_dim)
        batch_size = ped_embedding.size(0)
        # Share the scene-level map context across the batch without copying.
        ltn_embedding = self.ltn_embedding.unsqueeze(0).expand(batch_size, *self.ltn_embedding.shape)  # (batch_size, #latent_token, model_dim)

        # Build Mask (True = padded slot)
        ped_mask = self._build_padding_mask(ped_length, batch_size, ped_embedding.size(1))
        if ped_mask.all(dim=1).any():
            raise ValueError("Some batch samples have zero pedestrians, which is not allowed.")
        veh_mask = self._build_padding_mask(veh_length, batch_size, veh_embedding.size(1))
        # A fully masked memory row would make attention softmax undefined, so
        # samples without vehicles attend to slot 0. That slot is a padding or
        # NaN-placeholder vehicle.
        veh_mask[veh_mask.all(dim=1), 0] = False
        self._record_time(timer, 'Build Mask')

        # Social Attention
        ped_info = self.ped_attention(
            ped_embedding,
            ped_embedding,
            memory_key_padding_mask=ped_mask,
            tgt_key_padding_mask=ped_mask,
        )  # (batch_size, #pedestrian, model_dim)
        ped_info = F.layer_norm(ped_info, ped_info.shape[-1:])
        self._record_time(timer, 'Social Attention')

        # Vehicle Attention
        veh_info = self.veh_attention(
            ped_embedding,
            veh_embedding,
            memory_key_padding_mask=veh_mask,
            tgt_key_padding_mask=ped_mask,
        )  # (batch_size, #pedestrian, model_dim)
        veh_info = F.layer_norm(veh_info, veh_info.shape[-1:])
        self._record_time(timer, 'Vehicle Attention')

        # Map Attention: queries carry absolute positional encoding.
        map_info = self.map_attention(
            ped_embedding + self.pe,
            ltn_embedding,
            tgt_key_padding_mask=ped_mask,
        )  # (batch_size, #pedestrian, model_dim)
        map_info = F.layer_norm(map_info, map_info.shape[-1:])
        self._record_time(timer, 'Map Attention')

        # Fusion
        ped_embedding = self.fusion_fc(
            ped_embedding + ped_info + veh_info + map_info + self.sur_info + denoise_t_embedding
        )  # (batch_size, #pedestrian, model_dim)
        self._record_time(timer, 'Fusion')

        # Output
        output = self.output_fc(ped_embedding)  # (batch_size, #pedestrian, pred_step*2)
        output = output.view(*output.shape[:-1], self.args.pred_step, 2)  # (batch_size, #pedestrian, pred_step, 2)
        self._record_time(timer, 'Output')
        return output