import torch
import logging
import numpy as np
import pandas as pd
import torch.nn.functional as F
from typing import List, Tuple
from argparse import Namespace
from dataclasses import dataclass
from ..model.model import Model
from ..diffusion import DDPM
from ..dataset import BaseDataset
from ..utils.get_force_map import get_force_map
from ..utils.extract_patches import extract_patches_torch
from .guidance import guidance
# Module-level logger, registered under the fixed channel name 'src.simulate'.
_logger = logging.getLogger(name='src.simulate')
[docs]
@dataclass
class SimulateState:
    """Rollout state built by `init_simulation` and advanced by `simulate_one_step`.

    Every tensor field has a leading axis of size ``args.sample_num`` (``S``),
    one entry per sample of the same scene.
    """

    # Source trajectories indexed by (``f``, ``id``) with ``x``/``y`` columns.
    df_ped: pd.DataFrame  # Pedestrian rows.
    df_veh: pd.DataFrame  # Vehicle rows.
    map_data: object  # Scene map; the simulation reads ``map`` and ``xmin``/``xmax``/``ymin``/``ymax``.
    # Agents present at the initial frame, in the order used by the tensor axes.
    ped_list: List[int]  # Pedestrian IDs, (#pedestrian,)
    veh_list: List[int]  # Vehicle IDs, (#vehicle,)
    frame: int  # Current (most recently simulated) frame index.
    # Per-sample tensors.
    pos_now: torch.Tensor  # Pedestrian positions, (S, #pedestrian, 2)
    vel_now: torch.Tensor  # Pedestrian velocities, (S, #pedestrian, 2)
    hst_now: torch.Tensor  # The hist_step pedestrian positions before `frame`, (S, #pedestrian, hist_step, 2)
    des_now: torch.Tensor  # Pedestrian destinations, (S, #pedestrian, 2)
    spd_now: torch.Tensor  # Pedestrian desired speeds, (S, #pedestrian, 1)
    veh_now: torch.Tensor  # Vehicle positions up to and including `frame`, (S, #vehicle, hist_step + 1, 2)
[docs]
def init_simulation(
    args: Namespace, dataset: BaseDataset, frame_idx: int, model: Model
) -> SimulateState:
    """
    Initialize the state required for simulation.

    Every pedestrian and vehicle present in ``dataset.df_data`` at ``frame_idx``
    takes part. Per pedestrian this builds the current position, a
    finite-difference velocity, ``args.hist_step`` past positions, a destination
    and a desired speed. Each array is then tiled ``args.sample_num`` times along
    a new leading axis and moved to ``args.device``. As a side effect, the scene
    map is registered on ``model`` through ``model.set_map_embedding``.

    Args:
        args: Configuration arguments (reads ``fps``, ``hist_step``,
            ``no_destination``, ``no_speed``, ``sample_num``, ``denoise_step``,
            ``T``, ``step_offset`` and ``device``).
        dataset: Dataset object. Reads ``df_data`` (columns ``f``, ``id``,
            ``type``, ``x``, ``y``) and ``map_data``. If it has a non-empty
            ``user_destinations`` attribute (keyed by ``str(ped_id)`` with
            ``'x'``/``'y'`` entries), those destinations override the defaults.
        frame_idx: Initial frame index.
        model: Model object.

    Returns:
        SimulateState: Initialized simulation state.

    Raises:
        AssertionError: If ``args.T`` is not a multiple of ``args.denoise_step``,
            if ``args.step_offset`` lies outside ``[1, args.T // args.denoise_step]``,
            or if an intermediate array has an unexpected shape.
    """
    df_data = dataset.df_data.set_index(['f', 'id']).sort_index()
    df_ped = df_data.loc[df_data['type'] == 'pedestrian', ['x', 'y']]
    df_veh = df_data.loc[df_data['type'] == 'vehicle', ['x', 'y']]
    # `df_data` is sorted, so both ID lists come out in ascending order.
    ped_list = df_ped.loc[frame_idx].index.tolist() if frame_idx in df_ped.index else []
    veh_list = df_veh.loc[frame_idx].index.tolist() if frame_idx in df_veh.index else []

    # Current positions.
    pos = (
        df_ped
        .reindex(pd.MultiIndex.from_product([[frame_idx], ped_list], names=['f', 'id']))
        # .fillna(0.0) # There should be no NaN here.
        .values.reshape(len(ped_list), 2)  # (#pedestrian, 2)
    )
    assert pos.shape == (len(ped_list), 2)

    # Velocity = (pos[frame_idx] - pos[frame_idx - 1]) * fps; zero when the previous frame is missing.
    vel = (
        df_ped
        .reindex(pd.MultiIndex.from_product([[frame_idx - 1, frame_idx], ped_list], names=['f', 'id']))
        .unstack()
        .diff().mul(args.fps).iloc[1]
        .unstack().T
        .reindex(ped_list)  # Pin the row order to `ped_list` instead of relying on unstack's sorting.
        .fillna(0.0)
        .values  # (#pedestrian, 2)
    )
    assert vel.shape == (len(ped_list), 2)

    # Past positions over frames [frame_idx - hist_step, frame_idx); missing entries are NaN.
    hst = (
        df_ped
        .reindex(pd.MultiIndex.from_product([range(frame_idx - args.hist_step, frame_idx), ped_list], names=['f', 'id']))
        .values.reshape(args.hist_step, len(ped_list), 2)  # (hist_step, #pedestrian, 2)
        .transpose(1, 0, 2)  # (#pedestrian, hist_step, 2)
    )
    assert hst.shape == (len(ped_list), args.hist_step, 2)

    # Vehicle positions over frames [frame_idx - hist_step, frame_idx], i.e. including the current frame.
    veh = (
        df_veh
        .reindex(pd.MultiIndex.from_product([range(frame_idx - args.hist_step, frame_idx + 1), veh_list], names=['f', 'id']))
        .values.reshape(args.hist_step + 1, len(veh_list), 2)  # (hist_step + 1, #vehicle, 2)
        .transpose(1, 0, 2)  # (#vehicle, hist_step + 1, 2)
    )
    assert veh.shape == (len(veh_list), args.hist_step + 1, 2)

    # Default destination: the last recorded position of each pedestrian.
    # `to_numpy(copy=True)` returns a writable float array. `.values` can be a
    # read-only view under pandas Copy-on-Write, which would break the in-place
    # edits below; an integer dtype would break `des *= np.nan`.
    des = (
        df_ped
        .loc[df_ped.index.get_level_values('id').isin(ped_list)]
        .groupby(level=1, sort=False).tail(1)
        .swaplevel(axis=0).reindex(index=ped_list, level=0)
        .to_numpy(dtype=np.float64, copy=True)  # (#pedestrian, 2)
    )
    assert des.shape == (len(ped_list), 2)

    # Override destinations with user-defined values if present.
    # `user_destinations` uses string keys, so `ped_id` must be converted before lookup.
    user_destinations = getattr(dataset, 'user_destinations', None)
    if user_destinations:
        for idx, ped_id in enumerate(ped_list):
            ped_id_str = str(ped_id)
            if ped_id_str not in user_destinations:
                continue
            user_des = user_destinations[ped_id_str]
            des[idx, 0] = user_des['x']
            des[idx, 1] = user_des['y']
            _logger.info(f"Using user-defined destination for pedestrian {ped_id}: ({user_des['x']}, {user_des['y']})")

    spd = _estimate_desired_speed(args, df_ped, ped_list, frame_idx, vel)
    assert spd.shape == (len(ped_list), 1), "spd shape mismatch"

    map_data = dataset.map_data
    # Ablations: NaN destinations/speeds mark the quantity as unavailable.
    if args.no_destination:
        des *= np.nan
    if args.no_speed:
        spd *= np.nan

    ## Simulation
    S = args.sample_num  # Number of samples.
    N = args.denoise_step  # Number of denoising steps.
    assert args.T % N == 0, f"Requested {N} sampling steps, but training steps {args.T} mod {N} is not 0."
    assert 1 <= args.step_offset <= args.T // N, f"`step_offset` must lie in {{1, ..., {args.T // N}}}."
    pos_now = _tile_samples(pos, S, args.device)  # (S, #pedestrian, 2)
    vel_now = _tile_samples(vel, S, args.device)  # (S, #pedestrian, 2)
    hst_now = _tile_samples(hst, S, args.device)  # (S, #pedestrian, hist_step, 2)
    des_now = _tile_samples(des, S, args.device)  # (S, #pedestrian, 2)
    spd_now = _tile_samples(spd, S, args.device)  # (S, #pedestrian, 1)
    veh_now = _tile_samples(veh, S, args.device)  # (S, #vehicle, hist_step + 1, 2)
    model.set_map_embedding(
        map=torch.from_numpy(map_data.map).to(device=args.device, dtype=torch.float32),
        xmin=map_data.xmin,
        xmax=map_data.xmax,
        ymin=map_data.ymin,
        ymax=map_data.ymax,
    )
    return SimulateState(
        df_ped=df_ped,
        df_veh=df_veh,
        map_data=map_data,
        ped_list=ped_list,
        veh_list=veh_list,
        frame=frame_idx,
        pos_now=pos_now,
        vel_now=vel_now,
        hst_now=hst_now,
        des_now=des_now,
        spd_now=spd_now,
        veh_now=veh_now,
    )


def _estimate_desired_speed(
    args: Namespace, df_ped: pd.DataFrame, ped_list: List[int], frame_idx: int, vel: np.ndarray
) -> np.ndarray:
    """Estimate each pedestrian's desired speed as a writable (#pedestrian, 1) float array.

    If any frame in ``frame_idx + 1 .. frame_idx + int(5 * args.fps)`` exists in
    ``df_ped``, the estimate is the mean finite-difference speed over
    ``frame_idx .. frame_idx + int(5 * args.fps)``. Missing positions are
    forward- then backward-filled first. Simulated datasets may be truncated
    and lack those frames; in that case the current speed ``|vel|`` is used,
    raised to at least 1.0.
    """
    spd_data_range = range(frame_idx, frame_idx + int(5 * args.fps) + 1)
    available_frames = df_ped.index.get_level_values('f').unique()
    has_future_data = any(f in available_frames for f in spd_data_range if f > frame_idx)
    if has_future_data:
        mean_speed = (
            df_ped
            .reindex(pd.MultiIndex.from_product([spd_data_range, ped_list], names=['f', 'id']))
            .unstack().ffill().bfill().diff().mul(args.fps).iloc[1:]
            .stack(future_stack=True)
            .pow(2).sum(axis='columns', min_count=2).pow(0.5)
            .unstack().mean(axis='rows')
            .reindex(ped_list)  # Align with `ped_list` explicitly.
        )
        return mean_speed.to_numpy(dtype=np.float64, copy=True)[:, np.newaxis]
    # The simulated dataset is truncated, so use the current speed as the desired speed estimate.
    _logger.info("No future data for spd calculation, using current velocity as estimated desired speed")
    spd = np.linalg.norm(vel, axis=1, keepdims=True)
    # Use a reasonable default desired speed if the current speed is too small.
    default_spd = 1.0
    return np.where(spd < default_spd, default_spd, spd)


def _tile_samples(array: np.ndarray, sample_num: int, device) -> torch.Tensor:
    """Return `array` as float32 on `device`, repeated along a new leading axis of size `sample_num`."""
    tensor = torch.from_numpy(array).to(device=device, dtype=torch.float32)
    return tensor.unsqueeze(0).repeat(sample_num, *([1] * tensor.dim()))
[docs]
@torch.no_grad()
def simulate_one_step(
    args: Namespace, model: Model, diffusion: DDPM, state: SimulateState
) -> Tuple[pd.DataFrame, SimulateState]:
    """
    Advance the simulation by ``args.pred_step`` frames.

    Pedestrian accelerations for the next ``args.pred_step`` frames are sampled
    by strided reverse diffusion. When ``args.use_sfm`` is set, they are instead
    replaced by social-force-model accelerations. Accelerations are integrated
    with time step ``1 / args.fps``. Pedestrians that come within
    ``args.threshold_of_arrive`` of their destination (when that threshold is
    positive) are stopped. Vehicle positions are replayed from ``state.df_veh``.

    The function puts ``model`` in eval mode and runs under ``torch.no_grad()``.
    Unlike a global ``torch.set_grad_enabled(False)``, the caller's autograd
    setting is restored on return.

    Args:
        args: Configuration arguments.
        model: Model object; its vehicle/pedestrian embeddings and surrounding
            info are refreshed from ``state`` before sampling.
        diffusion: Diffusion process used to recover x0 and to denoise.
        state: Current simulation state. It is updated in place (``hst_now``,
            ``veh_now``, ``pos_now``, ``vel_now``, ``frame``) and also returned.

    Returns:
        (df_new, state): ``df_new`` holds one row per (agent, frame, sample) for
        the frames after the incoming ``state.frame``, with columns ``f``,
        ``id``, ``type``, ``x``, ``y``, ``sample``. ``state`` is the updated state.
    """
    model.eval()
    S = args.sample_num  # Number of samples.
    N = args.denoise_step  # Number of denoising steps.
    num_peds = len(state.ped_list)
    ped_length_repeat = torch.full((S,), num_peds, device=args.device, dtype=torch.long)  # (S,)
    veh_length_repeat = torch.full((S,), len(state.veh_list), device=args.device, dtype=torch.long)  # (S,)

    # Condition the model on the current scene.
    model.set_veh_embedding(veh=state.veh_now)
    model.set_ped_embedding(pos=state.pos_now, vel=state.vel_now, hst=state.hst_now, des=state.des_now, spd=state.spd_now)
    model.set_sur_info()

    # Strided reverse diffusion: visits t = ..., step_offset + stride, step_offset in descending order.
    xt = torch.randn([S, num_peds, args.pred_step, 2], device=args.device)  # Start from Gaussian noise.
    stride = args.T // N
    for t in reversed(range(args.step_offset, args.T + 1, stride)):
        noisy_acc = xt
        denoise_t = torch.full((xt.shape[0],), t, device=args.device, dtype=torch.long)
        output = model(
            noisy_acc=noisy_acc,
            denoise_t=denoise_t,
            ped_length=ped_length_repeat,
            veh_length=veh_length_repeat,
        )  # (S, #pedestrian, pred_step, 2)
        # Recover the estimated x0 (the network predicts either the noise or x0 itself).
        x0 = diffusion.noise_to_x0(xt, denoise_t, output) if args.predict_noise else output
        # Apply the guidance terms. The coefficients are typically around 0~1.
        x0 = x0 + guidance(args, x0, state, model, diffusion, noisy_acc, xt, denoise_t, ped_length_repeat, veh_length_repeat)
        # `min(stride, t)` keeps the last jump from going below t = 0.
        xt = diffusion.denoise(xt, t, x0=x0, stride=min(stride, t))
    acc_new = xt / args.scale_accelerate  # (S, #pedestrian, pred_step, 2)

    if args.use_sfm:
        acc_new = _social_force_acceleration(args, model, state)  # (S, #pedestrian, pred_step, 2)

    vel_new = state.vel_now.unsqueeze(-2) + acc_new.cumsum(dim=-2) / args.fps  # (S, #pedestrian, pred_step, 2)
    # Set the velocity of pedestrians who have reached their destination to 0.
    if args.threshold_of_arrive > 0:
        _pos_new = state.pos_now.unsqueeze(-2) + vel_new.cumsum(dim=-2) / args.fps  # (S, #pedestrian, pred_step, 2)
        arrived = (_pos_new - state.des_now[:, :, None, :]).norm(dim=-1) < args.threshold_of_arrive  # (S, #pedestrian, pred_step)
        arrived = arrived.cumsum(dim=-1) > 0  # Once arrived, stay arrived for the rest of the horizon.
        vel_new[arrived, :] = 0.0
    pos_new = state.pos_now.unsqueeze(-2) + vel_new.cumsum(dim=-2) / args.fps  # (S, #pedestrian, pred_step, 2)

    # History: the hist_step positions before the new current frame (pos_new[..., -1, :]).
    state.hst_now = torch.cat([state.hst_now, state.pos_now.unsqueeze(-2), pos_new], dim=-2)[:, :, -args.hist_step - 1:-1, :]
    if state.veh_list:
        veh_new = _future_vehicle_positions(args, state)  # (#vehicle, pred_step, 2)
        state.veh_now = torch.cat([state.veh_now, veh_new.unsqueeze(0).repeat(S, 1, 1, 1)], dim=-2)[:, :, -args.hist_step - 1:, :]  # (S, #vehicle, hist_step + 1, 2)
    state.pos_now = pos_new[:, :, -1, :]  # (S, #pedestrian, 2)
    state.vel_now = vel_new[:, :, -1, :]  # (S, #pedestrian, 2)
    # `des_now` and `spd_now` are left unchanged during the rollout.

    df_ped_new = _pedestrian_rows(pos_new, state.ped_list, state.frame + 1)
    df_veh_orig = state.df_veh.loc[state.frame + 1:state.frame + args.pred_step].reset_index().assign(type='vehicle')
    df_veh_new = pd.concat([df_veh_orig.assign(sample=s) for s in range(S)], ignore_index=True)
    df_new = pd.concat([df_ped_new, df_veh_new], ignore_index=True)
    df_new['sample'] = df_new['sample'].astype(int)
    state.frame = state.frame + args.pred_step
    return df_new, state


def _social_force_acceleration(args: Namespace, model: Model, state: SimulateState) -> torch.Tensor:
    """Social-force-model acceleration of every pedestrian, shape (S, #pedestrian, pred_step, 2).

    Sums five terms: a destination-driven force, scene-map repulsion,
    pedestrian-pedestrian repulsion, vehicle repulsion from each vehicle's
    latest position, and velocity damping. The forces are evaluated once at the
    current state and held constant over the prediction horizon. This keeps the
    time axis at ``pred_step`` frames, like the diffusion output.
    """
    cur_vel = state.vel_now.unsqueeze(-2)  # (S, #pedestrian, 1, 2)
    cur_pos = state.pos_now.unsqueeze(-2)  # (S, #pedestrian, 1, 2)

    # Destination-driven force: relax toward the desired speed along the direction to the destination.
    # NaN destinations/speeds contribute zero force.
    desire_vel = F.normalize(state.des_now.unsqueeze(-2) - cur_pos, dim=-1) * state.spd_now.unsqueeze(-2)
    des_force = (desire_vel - cur_vel).nan_to_num(0.0) / args.sfm_t_des

    # Obstacle repulsion: weight a force kernel by the map patch around each pedestrian.
    F_map = get_force_map(r=args.sfm_r_map, A=args.sfm_a_map, B=args.sfm_b_map, device=args.device)  # (2r+1, 2r+1, 2)
    map_nx, map_ny = model.map.shape[0], model.map.shape[1]
    idx = cur_pos[..., 0].sub(model.xmin).div(model.xmax - model.xmin).mul(map_nx).round().long().clamp(0, map_nx - 1)
    jdx = cur_pos[..., 1].sub(model.ymin).div(model.ymax - model.ymin).mul(map_ny).round().long().clamp(0, map_ny - 1)
    width = 2 * args.sfm_r_map + 1
    patches = extract_patches_torch(model.map, idx.reshape(-1), jdx.reshape(-1), r=args.sfm_r_map).reshape(*idx.shape, width, width)
    map_force = (patches[..., None] * F_map).nan_to_num(0.0).flatten(-3, -2).sum(-2)  # (S, #pedestrian, 1, 2)

    # Pedestrian repulsion: A * exp(-d / B) along the unit vector pointing away from each other pedestrian.
    p = cur_pos[:, None] - cur_pos[:, :, None]  # p[s, i, j] = pos_j - pos_i: (S, #focal, #other, 1, 2)
    d = torch.norm(p, dim=-1, keepdim=True)
    n = -p / d.clamp(min=1e-6)  # Zero for i == j, so a pedestrian does not repel itself.
    ped_force = (args.sfm_a_ped * torch.exp(-d / args.sfm_b_ped) * n).nan_to_num(0.0).sum(dim=2)

    # Vehicle repulsion, using each vehicle's most recent position.
    p = state.veh_now[:, :, None, -1:, :] - cur_pos[:, None]  # (S, #vehicle, #pedestrian, 1, 2)
    d = torch.norm(p, dim=-1, keepdim=True)
    n = -p / d.clamp(min=1e-6)
    veh_force = (args.sfm_a_veh * torch.exp(-d / args.sfm_b_veh) * n).nan_to_num(0.0).sum(dim=1)

    damp_force = -args.sfm_a_damp * cur_vel

    acc = des_force + map_force + ped_force + veh_force + damp_force  # (S, #pedestrian, 1, 2)
    return acc.expand(-1, -1, args.pred_step, -1)


def _future_vehicle_positions(args: Namespace, state: SimulateState) -> torch.Tensor:
    """Recorded positions of `state.veh_list` for the next `args.pred_step` frames.

    Returns a float32 tensor of shape (#vehicle, pred_step, 2) on ``args.device``.
    Frames or vehicles missing from ``state.df_veh`` are NaN.
    """
    table = (
        state.df_veh
        .loc[state.frame + 1:state.frame + args.pred_step]  # Label slicing is inclusive: exactly `pred_step` frames.
        .unstack().swaplevel(axis='columns').sort_index(axis='columns')
        .reindex(
            index=range(state.frame + 1, state.frame + args.pred_step + 1),
            columns=pd.MultiIndex.from_product([state.veh_list, ['x', 'y']]),
        )
    )
    array = table.values.reshape(args.pred_step, len(state.veh_list), 2).transpose(1, 0, 2)  # (#vehicle, pred_step, 2)
    return torch.from_numpy(array).to(device=args.device, dtype=torch.float32)


def _pedestrian_rows(pos_new: torch.Tensor, ped_list: List[int], first_frame: int) -> pd.DataFrame:
    """Flatten predicted positions (S, #pedestrian, pred_step, 2) into long-format rows.

    Rows are ordered by pedestrian, then frame, then sample. Frames are numbered
    consecutively from ``first_frame``.
    """
    num_samples, num_peds, num_steps, _ = pos_new.shape
    # A single device->host copy instead of one per element. float64 matches
    # the exact float32 -> Python float conversion.
    xy = pos_new.detach().permute(1, 2, 0, 3).reshape(-1, 2).to(device='cpu', dtype=torch.float64).numpy()
    frames = np.arange(first_frame, first_frame + num_steps, dtype=np.int64)
    return pd.DataFrame({
        'f': np.tile(np.repeat(frames, num_samples), num_peds),
        'id': [ped_id for ped_id in ped_list for _ in range(num_steps * num_samples)],
        'type': 'pedestrian',
        'x': xy[:, 0],
        'y': xy[:, 1],
        'sample': np.tile(np.arange(num_samples, dtype=np.int64), num_peds * num_steps),
    })