Source code for src.tasks.test_once

import time
import torch
import logging
import numpy as np
import torch.nn as nn
import torch.utils.data as D
from tqdm import tqdm
from typing import List, Dict
from ..model import Model
from ..diffusion import DDPM
from ..utils.timer import NamedTimer
from ..utils.tag2ansi import tag2ansi
from ..utils.calc_xy_error import calc_xy_error
from .visualize import visualize
from .guidance import guidance
from .simulate import SimulateState

_logger = logging.getLogger(__name__)


def get_xy_error(args, pos, veh, mask, pos_true, pos_pred, sample_idx, valid_idx):
    """
    Compute the normal and tangential trajectory errors of the selected samples.

    Each valid pedestrian is paired with the nearest vehicle of its own scene
    (position and velocity taken from the last vehicle frame), and the result
    of `calc_xy_error` on that pairing is returned.

    Args:
        args: Global arguments; only `args.fps` is read here.
        pos: Current pedestrian positions, (batch_size, #pedestrian, 2).
        veh: Vehicle history, (batch_size, #vehicle, hist_step + 1, 2).
        mask: Boolean valid-pedestrian mask, (batch_size, #pedestrian).
        pos_true: Ground-truth future positions, broadcastable against `pos_pred`.
        pos_pred: Predicted future positions, (S, batch_size, #pedestrian, steps, 2).
        sample_idx: Sample index chosen for each valid pedestrian.
        valid_idx: Running index over valid pedestrians (paired with `sample_idx`).

    Returns:
        (norm_err, tan_err): the two values produced by `calc_xy_error`.
    """
    nan = float('nan')

    # Drop padded pedestrians first, then keep one sample per valid pedestrian.
    diff_valid = (pos_pred - pos_true)[:, mask, :, :]
    traj_diff = diff_valid[sample_idx, valid_idx, :, :]  # (valid{B*#pedestrian}, roll_step*pred_step, 2)

    ped_pos = pos[mask, :]  # (valid{B*#pedestrian}, 2)
    scene_of_ped = torch.nonzero(mask)[:, 0]  # Scene (batch) index of every valid pedestrian.
    n_valid = scene_of_ped.shape[0]

    if veh.shape[1] != 0:
        # Latest vehicle state per scene; velocity is a finite difference of the last two frames.
        last_pos = veh[..., -1, :]  # (batch_size, #vehicle, 2)
        last_vel = (veh[..., -1, :] - veh[..., -2, :]) * args.fps  # (batch_size, #vehicle, 2)

        # Candidate vehicles for each valid pedestrian: those of its own scene.
        cand_pos = last_pos[scene_of_ped]  # (valid{B*#pedestrian}, #vehicle, 2)
        cand_vel = last_vel[scene_of_ped]  # (valid{B*#pedestrian}, #vehicle, 2)

        # NaN distances (invalid vehicles) become +inf so they can never be the minimum.
        offsets = ped_pos.unsqueeze(1) - cand_pos
        dist = offsets.norm(dim=-1).nan_to_num_(nan=float('inf'))  # (valid{B*#pedestrian}, #vehicle)
        min_dist, nearest_idx = torch.min(dist, dim=1)  # (valid{B*#pedestrian},)
        lacks_vehicle = min_dist == float('inf')

        # Pick the nearest vehicle's position and velocity.
        picker = nearest_idx.view(-1, 1, 1).expand(-1, 1, 2)
        veh_pos = torch.gather(cand_pos, 1, picker).squeeze(1)  # (valid{B*#pedestrian}, 2)
        veh_vel = torch.gather(cand_vel, 1, picker).squeeze(1)  # (valid{B*#pedestrian}, 2)

        # Pedestrians whose scene holds no valid vehicle get NaN placeholders.
        veh_pos[lacks_vehicle] = nan
        veh_vel[lacks_vehicle] = nan
    else:
        # No vehicle slot at all in this batch: everything is NaN.
        veh_pos = torch.full((n_valid, 2), nan, device=pos.device)
        veh_vel = torch.full((n_valid, 2), nan, device=pos.device)

    norm_err, tan_err = calc_xy_error(traj_diff, ped_pos, veh_pos, veh_vel)
    return norm_err, tan_err
def get_collision_rate(args, map_data, future_veh, veh_length, mask, pos_pred):
    """
    Compute pedestrian-pedestrian, pedestrian-vehicle and pedestrian-map collision rates.

    All three rates share the same normaliser, `S * mask.sum() * T`, i.e. one
    slot per sample, valid pedestrian and time step. Padded pedestrians
    (`mask == False`) never contribute to any of the three rates.

    Args:
        args: Global arguments; `args.collision_threshold` is the distance
            below which two agents are considered to collide.
        map_data: Object exposing a NumPy grid `map` and the bounds `xmin`,
            `xmax`, `ymin`, `ymax`. Grid cells with a value above 0.9 count
            as collisions.
        future_veh: Vehicle positions, (B, V, T, 2).
        veh_length: Number of valid vehicles per scene, (B,).
        mask: Boolean valid-pedestrian mask, (B, P).
        pos_pred: Pedestrian positions, (S, B, P, T, 2).

    Returns:
        (collision_ped, collision_veh, collision_map) as Python floats.
        `collision_veh` is NaN when the batch has no vehicle slot;
        `collision_map` is NaN when the map holds no finite value or when
        there is no valid pedestrian.
    """
    S, B, P, T, _ = pos_pred.shape
    threshold = args.collision_threshold
    num_slots = S * mask.sum() * T  # 0-dim tensor

    # (S, B, P, T, 2) -> (S, B, T, P, 2) -> (S*B*T, P, 2)
    flat_pos = pos_pred.permute(0, 1, 3, 2, 4).reshape(-1, P, 2)
    # (B, P) -> (1, B, 1, P) -> (S, B, T, P) -> (S*B*T, P)
    ped_valid = mask.unsqueeze(0).unsqueeze(2).expand(S, -1, T, -1).reshape(-1, P)

    # ---- Pedestrian vs. pedestrian ----
    ped_dist = torch.cdist(flat_pos, flat_pos, p=2)  # (S*B*T, P, P)
    not_self = ~torch.eye(P, device=pos_pred.device, dtype=torch.bool).unsqueeze(0)
    # Count a collision only when both pedestrian i and pedestrian j are valid.
    both_valid = ped_valid.unsqueeze(2) & ped_valid.unsqueeze(1)
    ped_hits = (ped_dist < threshold) & not_self & both_valid
    collision_ped = (ped_hits.sum() / num_slots).item()

    # ---- Pedestrian vs. vehicle ----
    if future_veh.shape[1] == 0:
        collision_veh = float('nan')
    else:
        _, V, _, _ = future_veh.shape  # (B, V, T, 2)
        # (B, V, T, 2) -> (1, B, V, T, 2) -> (S, B, T, V, 2) -> (S*B*T, V, 2)
        flat_veh = (
            future_veh.unsqueeze(0)
            .expand(S, -1, -1, -1, -1)
            .permute(0, 1, 3, 2, 4)
            .reshape(-1, V, 2)
        )
        veh_dist = torch.cdist(flat_pos, flat_veh, p=2)  # (S*B*T, P, V)
        # Build the index on `veh_length`'s own device so the comparison can
        # never mix devices.
        veh_valid = torch.arange(V, device=veh_length.device).expand(B, V) < veh_length.unsqueeze(-1)  # (B, V)
        # (B, V) -> (1, B, 1, V) -> (S, B, T, V) -> (S*B*T, V)
        veh_valid = veh_valid.unsqueeze(0).unsqueeze(2).expand(S, -1, T, -1).reshape(-1, V)
        # Count a collision only when pedestrian i and vehicle j are both valid.
        pair_valid = ped_valid.unsqueeze(2) & veh_valid.unsqueeze(1)  # (S*B*T, P, V)
        veh_hits = (veh_dist < threshold) & pair_valid
        # Divide by 2 because only one side of the collision is a pedestrian.
        collision_veh = (veh_hits.sum() / 2 / num_slots).item()

    # ---- Pedestrian vs. map ----
    grid = map_data.map
    if not np.isfinite(grid).any():
        collision_map = float('nan')
    else:
        rows, cols = grid.shape[0], grid.shape[1]
        row = (
            pos_pred[..., 0].sub(map_data.xmin).div(map_data.xmax - map_data.xmin)
            .mul(rows).round().long().clamp(0, rows - 1)
        )  # (S, B, P, T)
        col = (
            pos_pred[..., 1].sub(map_data.ymin).div(map_data.ymax - map_data.ymin)
            .mul(cols).round().long().clamp(0, cols - 1)
        )  # (S, B, P, T)
        blocked = grid[row.cpu().numpy(), col.cpu().numpy()] > 0.9  # (S, B, P, T)

        # Restrict to valid pedestrians: a padded slot still resolves to some
        # grid cell after clamping and must not be counted as a collision.
        valid_np = mask.cpu().numpy()[None, :, :, None]  # (1, B, P, 1)
        slots = S * int(valid_np.sum()) * T
        hits = int((blocked & valid_np).sum())
        collision_map = hits / slots if slots > 0 else float('nan')

    return collision_ped, collision_veh, collision_map
def get_collision_rate2(args, map_data, future_veh, veh_length, mask, pos_pred, pos_true):
    """
    Compute collision rates that count only collisions absent from the ground truth.

    The same three detectors as `get_collision_rate` are evaluated on both the
    predicted and the ground-truth trajectories; a predicted collision is
    counted only if the same (scene, time step, pair/cell) is collision-free in
    the ground truth. All rates are normalised by `S * mask.sum() * T`, and
    padded pedestrians (`mask == False`) never contribute.

    Args:
        args: Global arguments; `args.collision_threshold` is the distance
            below which two agents are considered to collide.
        map_data: Object exposing a NumPy grid `map` and the bounds `xmin`,
            `xmax`, `ymin`, `ymax`. Grid cells with a value above 0.9 count
            as collisions.
        future_veh: Vehicle positions, (B, V, T, 2).
        veh_length: Number of valid vehicles per scene, (B,).
        mask: Boolean valid-pedestrian mask, (B, P).
        pos_pred: Predicted pedestrian positions, (S, B, P, T, 2).
        pos_true: Ground-truth pedestrian positions, (B, P, T, 2).

    Returns:
        (collision_ped, collision_veh, collision_map) as Python floats.
        `collision_veh` is NaN when the batch has no vehicle slot;
        `collision_map` is NaN when the map holds no finite value or when
        there is no valid pedestrian.
    """
    S, B, P, T, _ = pos_pred.shape
    threshold = args.collision_threshold
    num_slots = S * mask.sum() * T  # 0-dim tensor
    true_5d = pos_true.unsqueeze(0)  # (1, B, P, T, 2): ground truth as a single "sample"

    def flatten_time(x):
        # (N, B, K, T, 2) -> (N, B, T, K, 2) -> (N*B*T, K, 2)
        return x.permute(0, 1, 3, 2, 4).reshape(-1, x.shape[2], 2)

    def expand_valid(valid, n):
        # (B, K) -> (1, B, 1, K) -> (n, B, T, K) -> (n*B*T, K)
        return valid.unsqueeze(0).unsqueeze(2).expand(n, -1, T, -1).reshape(-1, valid.shape[1])

    flat_pred, flat_true = flatten_time(pos_pred), flatten_time(true_5d)
    valid_pred, valid_true = expand_valid(mask, S), expand_valid(mask, 1)

    # ---- Pedestrian vs. pedestrian ----
    not_self = ~torch.eye(P, device=pos_pred.device, dtype=torch.bool).unsqueeze(0)

    def ped_ped_hits(flat, valid):
        # Count a collision only when both pedestrian i and pedestrian j are valid.
        close = torch.cdist(flat, flat, p=2) < threshold  # (n*B*T, P, P)
        return close & not_self & (valid.unsqueeze(2) & valid.unsqueeze(1))

    # Count only collisions that appear in prediction but not in ground truth.
    new_ped = (
        ped_ped_hits(flat_pred, valid_pred).reshape(S, B, T, P, P)
        & (~ped_ped_hits(flat_true, valid_true)).reshape(1, B, T, P, P)
    )
    collision_ped = (new_ped.sum() / num_slots).item()

    # ---- Pedestrian vs. vehicle ----
    if future_veh.shape[1] == 0:
        collision_veh = float('nan')
    else:
        _, V, _, _ = future_veh.shape  # (B, V, T, 2)
        # Build the index on `veh_length`'s own device so the comparison can
        # never mix devices.
        veh_valid = torch.arange(V, device=veh_length.device).expand(B, V) < veh_length.unsqueeze(-1)  # (B, V)

        def ped_veh_hits(flat, valid, n):
            # (B, V, T, 2) -> (n, B, V, T, 2) -> (n*B*T, V, 2)
            flat_veh = flatten_time(future_veh.unsqueeze(0).expand(n, -1, -1, -1, -1))
            close = torch.cdist(flat, flat_veh, p=2) < threshold  # (n*B*T, P, V)
            # Count a collision only when pedestrian i and vehicle j are both valid.
            return close & (valid.unsqueeze(2) & expand_valid(veh_valid, n).unsqueeze(1))

        new_veh = (
            ped_veh_hits(flat_pred, valid_pred, S).reshape(S, B, T, P, V)
            & (~ped_veh_hits(flat_true, valid_true, 1)).reshape(1, B, T, P, V)
        )
        # Divide by 2 because only one side of the collision is a pedestrian.
        collision_veh = (new_veh.sum() / 2 / num_slots).item()

    # ---- Pedestrian vs. map ----
    grid = map_data.map
    if not np.isfinite(grid).any():
        collision_map = float('nan')
    else:
        rows, cols = grid.shape[0], grid.shape[1]

        def blocked(points):
            # Map (x, y) coordinates to clamped grid indices, then test the cell value.
            row = (
                points[..., 0].sub(map_data.xmin).div(map_data.xmax - map_data.xmin)
                .mul(rows).round().long().clamp(0, rows - 1)
            )
            col = (
                points[..., 1].sub(map_data.ymin).div(map_data.ymax - map_data.ymin)
                .mul(cols).round().long().clamp(0, cols - 1)
            )
            return grid[row.cpu().numpy(), col.cpu().numpy()] > 0.9

        # Restrict to valid pedestrians so the numerator matches the
        # valid-only denominator; the whole rate is computed on the NumPy side
        # to avoid dividing a NumPy scalar by a torch tensor.
        valid_np = mask.cpu().numpy()[None, :, :, None]  # (1, B, P, 1)
        new_map = blocked(pos_pred) & ~blocked(true_5d) & valid_np  # (S, B, P, T)
        slots = S * int(valid_np.sum()) * T
        collision_map = int(new_map.sum()) / slots if slots > 0 else float('nan')

    return collision_ped, collision_veh, collision_map
def test_once(
    args,
    test_loaders: List[D.DataLoader],
    model: Model,
    criterion: nn.Module,
    diffusion: DDPM,
    epoch: int,
) -> Dict:
    """
    Run one evaluation epoch over the provided loaders.

    For every batch the model is rolled out `args.roll_step` times; each
    roll-out step draws `args.sample_num` samples by running
    `args.denoise_step` denoising steps from Gaussian noise. Metrics are
    collected per loader, logged per loader, then aggregated over loaders
    (weighted by dataset size) and, when several dataset classes are
    present, per dataset class.

    Args:
        args: Global arguments.
        test_loaders: Test dataloaders.
        model: Model under evaluation.
        criterion: Loss function.
        diffusion: Diffusion module.
        epoch: Current epoch index.

    Returns:
        all_records: Evaluation metrics dictionary. Holds `epoch`, the
            per-loader lists `dataset_class`, `dataset_names`, `sample_nums`,
            one per-loader list of means for every recorded metric, and the
            scalars `accuracy` and `unweighted_accuracy`.
    """
    test_timer = NamedTimer(unit='it', mode='pace')
    records_list = []

    for loader in test_loaders:
        map_data = loader.dataset.map_data
        # Named `map_tensor` to avoid shadowing the builtin `map`.
        map_tensor = torch.from_numpy(map_data.map).to(args.device).float()
        test_timer.add('prepare data')
        model.set_map_embedding(
            map=map_tensor,
            xmin=map_data.xmin,
            xmax=map_data.xmax,
            ymin=map_data.ymin,
            ymax=map_data.ymax,
        )
        test_timer.add('embed map')

        records = dict(
            loss=[], ade=[], fde=[], trajlen=[], ped_num=[], veh_num=[], rollout_time=[],
            collision_ped=[], collision_veh=[], collision_map=[],
            collision_ped_base=[], collision_veh_base=[], collision_map_base=[],
            collision_ped2=[], collision_veh2=[], collision_map2=[],
            apd=[], norm_err=[], tan_err=[]
        )

        for batch_idx, batch in enumerate(tqdm(loader, disable=False, leave=False, dynamic_ncols=True)):
            pos = batch['pos'].to(args.device)  # (batch_size, #pedestrian, 2)
            vel = batch['vel'].to(args.device)  # (batch_size, #pedestrian, 2)
            hst = batch['hst'].to(args.device)  # (batch_size, #pedestrian, hist_step, 2)
            des = batch['des'].to(args.device)  # (batch_size, #pedestrian, 2)
            spd = batch['spd'].to(args.device)  # (batch_size, #pedestrian, 1)
            veh = batch['veh'].to(args.device)  # (batch_size, #vehicle, hist_step + 1, 2)
            future_acc = batch['future_acc'].to(args.device)  # (batch_size, #pedestrian, pred_step*roll_step, 2)
            future_pos = batch['future_pos'].to(args.device)  # (batch_size, #pedestrian, pred_step*roll_step, 2)
            future_veh = batch['future_veh'].to(args.device)  # (batch_size, #vehicle, pred_step*roll_step, 2)
            ped_length = batch['ped_length'].to(args.device)  # (batch_size,)
            veh_length = batch['veh_length'].to(args.device)  # (batch_size,)

            S = args.sample_num  # Number of samples.
            N = args.denoise_step  # Number of denoising steps.
            assert args.T % N == 0, f"Requested {N} sampling steps, but training steps {args.T} mod {N} is not 0."
            assert 1 <= args.step_offset <= args.T // N, f"`step_offset` must lie in {{1, ..., {args.T // N}}}."

            # Tile the batch S times along dim 0 so all samples are drawn in one pass.
            pos_now = pos.repeat(S, 1, 1)  # (S*B, #pedestrian, 2)
            vel_now = vel.repeat(S, 1, 1)  # (S*B, #pedestrian, 2)
            hst_now = hst.repeat(S, 1, 1, 1)  # (S*B, #pedestrian, hist_step, 2)
            des_now = des.repeat(S, 1, 1)  # (S*B, #pedestrian, 2)
            spd_now = spd.repeat(S, 1, 1)  # (S*B, #pedestrian, 1)
            veh_now = veh.repeat(S, 1, 1, 1)  # (S*B, #vehicle, hist_step + 1, 2)
            ped_length_repeat = ped_length.repeat(S)  # (S*B,)
            veh_length_repeat = veh_length.repeat(S)  # (S*B,)
            test_timer.add('prepare data', n=0)

            for_plot = []
            acc_pred = []
            start_time = time.time()
            for step in range(args.roll_step):
                if not args.cache_latent_query:
                    model.set_map_embedding(
                        map=map_tensor,
                        xmin=map_data.xmin,
                        xmax=map_data.xmax,
                        ymin=map_data.ymin,
                        ymax=map_data.ymax,
                    )
                model.set_veh_embedding(veh=veh_now)
                model.set_ped_embedding(pos=pos_now, vel=vel_now, hst=hst_now, des=des_now, spd=spd_now)
                model.set_sur_info()
                test_timer.add('embed data')

                shape = list(future_acc.shape)
                shape[0] *= S
                shape[2] = args.pred_step
                xt = torch.randn(shape, device=args.device)  # Start from Gaussian noise.
                for_plot.append([diffusion.noise_to_x0(xt=xt, denoise_t=args.T, noise=0) / args.scale_accelerate])

                stride = args.T // N
                steps = reversed(range(args.step_offset, args.T+1, stride))
                for t in tqdm(steps, disable=True, leave=False, dynamic_ncols=True):
                    noisy_acc = xt
                    denoise_t = torch.full((xt.shape[0],), t, device=args.device, dtype=torch.long)
                    output = model(
                        noisy_acc=noisy_acc,
                        denoise_t=denoise_t,
                        ped_length=ped_length_repeat,
                        veh_length=veh_length_repeat,
                    )  # (S*B, #pedestrian, pred_step, 2)

                    ## Recover the estimated x0.
                    x0 = diffusion.noise_to_x0(xt, denoise_t, output) if args.predict_noise else output

                    ## Apply different guidance terms. The coefficients are typically around 0~1.
                    state = SimulateState(
                        df_ped=None, df_veh=None, map_data=None, ped_list=None, veh_list=None, frame=None,
                        pos_now=pos_now, vel_now=vel_now, des_now=des_now, hst_now=hst_now, spd_now=spd_now,
                        veh_now=veh_now,
                    )
                    x0 = x0 + guidance(args, x0, state, model, diffusion, noisy_acc, xt, denoise_t, ped_length_repeat, veh_length_repeat)

                    ## Denoise.
                    xt = diffusion.denoise(xt, t, x0=x0, stride=min(stride, t))
                    for_plot[-1].append(x0 / args.scale_accelerate)

                acc_new = xt / args.scale_accelerate  # (S*B, #pedestrian, pred_step, 2)
                acc_pred.append(acc_new)
                test_timer.add('denoise')

                # Integrate acceleration -> velocity -> position, then slide the history windows.
                vel_new = vel_now.unsqueeze(-2) + acc_new.cumsum(dim=-2) / args.fps  # (S*B, #pedestrian, pred_step, 2)
                pos_new = pos_now.unsqueeze(-2) + vel_new.cumsum(dim=-2) / args.fps  # (S*B, #pedestrian, pred_step, 2)
                veh_new = future_veh[:, :, step*args.pred_step:(step+1)*args.pred_step, :].repeat(S, 1, 1, 1)  # (S*B, #vehicle, pred_step, 2)
                hst_now = torch.cat([hst_now, pos_now.unsqueeze(-2), pos_new], dim=-2)[:, :, -args.hist_step-1:-1, :]  # (S*B, #pedestrian, hist_step, 2)
                veh_now = torch.cat([veh_now, veh_new], dim=-2)[:, :, -args.hist_step-1:, :]  # (S*B, #vehicle, hist_step + 1, 2)
                pos_now = pos_new[:, :, -1, :]  # (S*B, #pedestrian, 2)
                vel_now = vel_new[:, :, -1, :]  # (S*B, #pedestrian, 2)
                test_timer.add('rollout')

            rollout_time = (time.time() - start_time) / args.roll_step
            acc_pred = torch.concat(acc_pred, dim=-2)  # (S*B, #pedestrian, roll_step*pred_step, 2)
            batch_size, ped_num, _, _ = future_acc.shape
            acc_pred = acc_pred.view(S, batch_size, ped_num, args.roll_step*args.pred_step, 2)  # (S, B, #pedestrian, roll_step*pred_step, 2)

            # Get the valid-pedestrian mask.
            mask = torch.arange(ped_num, device=args.device).expand(batch_size, ped_num) < ped_length.unsqueeze(-1)  # (B, #pedestrian)

            # Compute `pos_true` and `vel_true`.
            acc_true = future_acc  # (B, #pedestrian, roll_step*pred_step, 2)
            vel_true = vel.unsqueeze(-2) + acc_true.cumsum(dim=-2) / args.fps  # (B, #pedestrian, roll_step*pred_step, 2)
            pos_true = pos.unsqueeze(-2) + vel_true.cumsum(dim=-2) / args.fps  # (B, #pedestrian, roll_step*pred_step, 2)
            # Sanity check: how far the integrated ground truth drifts from the stored future positions.
            max_err = np.nanmax((pos_true - future_pos).abs().cpu().numpy(), axis=(-2, -1))
            _logger.debug(
                f"max gap between `pos_true` and `future` > 1: {(max_err > 1).mean():.2%}, "
                f"max gap between `pos_true` and `future` > 1e-6: {(max_err > 1e-6).mean():.2%}"
            )
            # pos_true = future_pos  # (B, #pedestrian, roll_step*pred_step, 2)

            # Compute loss.
            loss = criterion(acc_pred, acc_true.expand(acc_pred.shape))  # float
            records['loss'].extend([loss.item()] * future_acc.shape[0])  # List[float]

            # Compute distance error.
            vel_pred = vel.unsqueeze(-2) + acc_pred.cumsum(dim=-2) / args.fps  # (S, B, #pedestrian, roll_step*pred_step, 2)
            pos_pred = pos.unsqueeze(-2) + vel_pred.cumsum(dim=-2) / args.fps  # (S, B, #pedestrian, roll_step*pred_step, 2)
            dis_err = (pos_pred - pos_true).norm(dim=-1)  # (S, B, #pedestrian, roll_step*pred_step)
            test_timer.add('evaluate')

            # Visualization (currently switched off by the constant `False`).
            if False and batch_idx == 0:
                pid = 0
                save_path = f"{args.save_path}/visualize/epoch{epoch}_{loader.dataset.name}_idx{batch_idx}_pid{pid}.png"
                visualize(args, pos, vel, hst, for_plot, mask, pos_true, pos_pred, save_path, pid)
                test_timer.add('visualize')

            # Remove padded pedestrians.
            dis_err = dis_err[:, mask, :]  # (S, valid{B*#pedestrian}, roll_step*pred_step)

            # Select the sample with the best ADE.
            sample_idx = dis_err.mean(dim=-1).argmin(dim=0)  # (valid{B*#pedestrian},)
            valid_idx = torch.arange(dis_err.shape[1], device=args.device)  # (valid{B*#pedestrian},)
            dis_err = dis_err[sample_idx, valid_idx, :]  # (valid{B*#pedestrian}, roll_step*pred_step)

            # Compute ADE and FDE.
            ade = dis_err.mean(dim=-1)  # (valid{B*#pedestrian})
            fde = dis_err[..., -1]  # (valid{B*#pedestrian})
            records['ade'].extend(ade.cpu().tolist())  # List[float]
            records['fde'].extend(fde.cpu().tolist())  # List[float]

            # Compute tangential and normal errors.
            norm_err, tan_err = get_xy_error(args, pos, veh, mask, pos_true, pos_pred, sample_idx, valid_idx)
            records['norm_err'].extend(norm_err.cpu().tolist())  # List[float]
            records['tan_err'].extend(tan_err.cpu().tolist())  # List[float]

            # Compute collision counts: prediction, ground-truth baseline, and prediction-only.
            collision_ped, collision_veh, collision_map = get_collision_rate(args, map_data, future_veh, veh_length, mask, pos_pred)
            records['collision_ped'].append(collision_ped)
            records['collision_veh'].append(collision_veh)
            records['collision_map'].append(collision_map)
            collision_ped_base, collision_veh_base, collision_map_base = get_collision_rate(args, map_data, future_veh, veh_length, mask, pos_true.unsqueeze(0))
            records['collision_ped_base'].append(collision_ped_base)
            records['collision_veh_base'].append(collision_veh_base)
            records['collision_map_base'].append(collision_map_base)
            collision_ped2, collision_veh2, collision_map2 = get_collision_rate2(args, map_data, future_veh, veh_length, mask, pos_pred, pos_true)
            records['collision_ped2'].append(collision_ped2)
            records['collision_veh2'].append(collision_veh2)
            records['collision_map2'].append(collision_map2)

            # Compute APD diversity.
            tmp = pos_pred[:, mask, :, :]  # (S, valid{B*#pedestrian}, roll_step*pred_step, 2)
            tmp = (tmp[None, :, ...] - tmp[:, None, ...]).norm(dim=-1).mean(dim=-1)  # (S, S, valid{B*#pedestrian})
            apd = tmp.flatten(0, 1).sum(dim=0) / (S * (S - 1))  # (valid{B*#pedestrian},)
            records['apd'].extend(apd.cpu().tolist())  # List[float]

            # Compute trajectory length.
            trajlen = pos_true.diff(dim=-2).norm(dim=-1).sum(dim=-1)[mask]  # (valid{B*#pedestrian})
            records['trajlen'].extend(trajlen.cpu().tolist())  # List[float]

            # Count pedestrians and vehicles.
            records['ped_num'].extend(ped_length.cpu().tolist())  # List[int]
            records['veh_num'].extend(veh_length.cpu().tolist())  # List[int]

            # Record rollout time.
            records['rollout_time'].append(rollout_time)  # List[float]
            test_timer.add('evaluate', n=0)

        records_list.append(records)
        _logger.info(tag2ansi(
            f"[#66CCFF][Epoch {epoch}/{args.epochs}] Eval on {loader.dataset.name}: "
            f"[bold underline orange]Accuracy={1 - np.mean(records['ade']) / np.mean(records['trajlen']):.2%}[reset], "
            f"[#66CCFF]Loss={np.mean(records['loss']):.4f}, "
            f"[#66CCFF]ADE={np.mean(records['ade']):.4f}, "
            f"[#66CCFF]FDE={np.mean(records['fde']):.4f}, "
            f"[#66CCFF]X_ERROR (normal)={np.nanmean(records['norm_err']):.4f}, "
            f"[#66CCFF]Y_ERROR (tangential)={np.nanmean(records['tan_err']):.4f}, "
            f"[#66CCFF]Collision-Ped={np.mean(records['collision_ped']) - (base:= np.mean(records['collision_ped_base'])):.2%} (+{base:.2%}), "
            f"[#66CCFF]Collision-Veh={np.mean(records['collision_veh']) - (base:= np.mean(records['collision_veh_base'])):.2%} (+{base:.2%}), "
            f"[#66CCFF]Collision-Map={np.mean(records['collision_map']) - (base:= np.mean(records['collision_map_base'])):.2%} (+{base:.2%}), "
            f"[#66CCFF]Collision-Ped2={np.mean(records['collision_ped2']):.2%}, "
            f"[#66CCFF]Collision-Veh2={np.mean(records['collision_veh2']):.2%}, "
            f"[#66CCFF]Collision-Map2={np.mean(records['collision_map2']):.2%}, "
            f"[#66CCFF]APD={np.mean(records['apd']):.4f}, "
            f"[#66CCFF]AvgLen={np.mean(records['trajlen']):.4f}, "
            f"[#66CCFF]PedNum={np.mean(records['ped_num']):.1f}, "
            f"[#66CCFF]VehNum={np.mean(records['veh_num']):.1f}, "
            f"[#66CCFF]RolloutTime={np.mean(records['rollout_time'])*1000:.2f}ms "
            f"([bold underline orange]FPS={1/np.mean(records['rollout_time']):.2f} Hz[reset])"
        ))

    # Aggregate: one mean per metric and loader.
    all_records = {
        'epoch': epoch,
        'dataset_class': [type(loader.dataset).__name__.removesuffix('Dataset') for loader in test_loaders],
        'dataset_names': [loader.dataset.name for loader in test_loaders],
        'sample_nums': [len(loader.dataset) for loader in test_loaders],
    }
    for records in records_list:
        for k, v in records.items():
            all_records.setdefault(k, []).append(np.mean(v))

    # Weight each loader by its dataset size.
    w = np.array(all_records['sample_nums'], dtype=float)
    w /= w.sum()
    all_records['accuracy'] = 1 - np.sum(w * all_records['ade']) / np.sum(w * all_records['trajlen'])
    all_records['unweighted_accuracy'] = 1 - np.mean(all_records['ade']) / np.mean(all_records['trajlen'])
    # NOTE(review): `Logger.note` is not part of the standard `logging` API —
    # presumably a custom level registered elsewhere in the project; confirm.
    _logger.note(tag2ansi(
        f"[#66CCFF][Epoch {epoch}/{args.epochs}] Overall: "
        f"[bold underline orange]Accuracy={all_records['accuracy']:.2%}[reset] (unweighted={all_records['unweighted_accuracy']:.2%}), "
        f"[#66CCFF]Loss={np.sum(w * all_records['loss']):.4f}, "
        f"[#66CCFF]ADE={np.sum(w * all_records['ade']):.4f}, "
        f"[#66CCFF]FDE={np.sum(w * all_records['fde']):.4f}, "
        f"[#66CCFF]X_ERROR (normal)={np.nansum(w * all_records['norm_err']) / np.sum(w * np.isfinite(all_records['norm_err'])):.4f}, "
        f"[#66CCFF]Y_ERROR (tangential)={np.nansum(w * all_records['tan_err']) / np.sum(w * np.isfinite(all_records['tan_err'])):.4f}, "
        f"[#66CCFF]Collision-Ped={np.sum(w * all_records['collision_ped']) - (base := np.sum(w * all_records['collision_ped_base'])):.2%} (+{base:.2%}), "
        f"[#66CCFF]Collision-Veh={np.sum(w * all_records['collision_veh']) - (base := np.sum(w * all_records['collision_veh_base'])):.2%} (+{base:.2%}), "
        f"[#66CCFF]Collision-Map={np.sum(w * all_records['collision_map']) - (base := np.sum(w * all_records['collision_map_base'])):.2%} (+{base:.2%}), "
        f"[#66CCFF]Collision-Ped2={np.sum(w * all_records['collision_ped2']):.2%}, "
        f"[#66CCFF]Collision-Veh2={np.sum(w * all_records['collision_veh2']):.2%}, "
        f"[#66CCFF]Collision-Map2={np.sum(w * all_records['collision_map2']):.2%}, "
        f"[#66CCFF]APD={np.sum(w * all_records['apd']):.4f}, "
        f"[#66CCFF]AvgLen={np.sum(w * all_records['trajlen']):.4f}, "
        f"[#66CCFF]PedNum={np.sum(w * all_records['ped_num']):.4f}, "
        f"[#66CCFF]VehNum={np.sum(w * all_records['veh_num']):.4f}, "
        # `.2f` (fixed-point), matching the per-dataset and per-class logs;
        # a bare `.2` would print two significant digits in general format.
        f"[#66CCFF]RolloutTime={np.mean(all_records['rollout_time'])*1000:.2f}ms "
        f"([bold underline orange]FPS={1/np.mean(all_records['rollout_time']):.2f} Hz[reset]), "
        f"[#66CCFF]Time={test_timer}"
    ))

    def pick(key, idxs):
        # Per-loader means of `key`, restricted to the loaders at `idxs`.
        return np.array([all_records[key][i] for i in idxs])

    # Per-class summary, only meaningful when more than one dataset class is present.
    if len(set(all_records['dataset_class'])) > 1:
        for klass in sorted(set(all_records['dataset_class'])):
            idxs = [i for i, k in enumerate(all_records['dataset_class']) if k == klass]
            ade = pick('ade', idxs)
            fde = pick('fde', idxs)
            norm_err = pick('norm_err', idxs)
            tan_err = pick('tan_err', idxs)
            trajlen = pick('trajlen', idxs)
            ped_num = pick('ped_num', idxs)
            veh_num = pick('veh_num', idxs)
            collision_ped = pick('collision_ped', idxs)
            collision_veh = pick('collision_veh', idxs)
            collision_map = pick('collision_map', idxs)
            collision_ped_base = pick('collision_ped_base', idxs)
            collision_veh_base = pick('collision_veh_base', idxs)
            collision_map_base = pick('collision_map_base', idxs)
            collision_ped2 = pick('collision_ped2', idxs)
            collision_veh2 = pick('collision_veh2', idxs)
            collision_map2 = pick('collision_map2', idxs)
            rollout_time = pick('rollout_time', idxs)
            apd = pick('apd', idxs)
            w = np.array([all_records['sample_nums'][i] for i in idxs], dtype=float)
            w /= w.sum()
            acc = 1 - np.sum(w * ade) / np.sum(w * trajlen)
            _logger.note(tag2ansi(
                f"[#66CCFF][Epoch {epoch}/{args.epochs}] Overall on {klass} datasets: "
                f"[bold underline orange]Accuracy={acc:.2%}[reset], "
                f"[#66CCFF]ADE={np.sum(w * ade):.4f}, "
                f"[#66CCFF]FDE={np.sum(w * fde):.4f}, "
                f"[#66CCFF]X_ERROR (normal)={np.nansum(w * norm_err) / np.sum(w * np.isfinite(norm_err)):.4f}, "
                f"[#66CCFF]Y_ERROR (tangential)={np.nansum(w * tan_err) / np.sum(w * np.isfinite(tan_err)):.4f}, "
                f"[#66CCFF]Collision-Ped={np.sum(w * collision_ped) - (base := np.sum(w * collision_ped_base)):.2%} (+{base:.2%}), "
                f"[#66CCFF]Collision-Veh={np.sum(w * collision_veh) - (base := np.sum(w * collision_veh_base)):.2%} (+{base:.2%}), "
                f"[#66CCFF]Collision-Map={np.sum(w * collision_map) - (base := np.sum(w * collision_map_base)):.2%} (+{base:.2%}), "
                f"[#66CCFF]Collision-Ped2={np.sum(w * collision_ped2):.2%}, "
                f"[#66CCFF]Collision-Veh2={np.sum(w * collision_veh2):.2%}, "
                f"[#66CCFF]Collision-Map2={np.sum(w * collision_map2):.2%}, "
                f"[#66CCFF]APD={np.sum(w * apd):.4f}, "
                f"[#66CCFF]AvgLen={np.sum(w * trajlen):.4f}, "
                f"[#66CCFF]PedNum={np.sum(w * ped_num):.4f}, "
                f"[#66CCFF]VehNum={np.sum(w * veh_num):.4f}, "
                f"[#66CCFF]RolloutTime={np.mean(rollout_time)*1000:.2f}ms "
                f"([bold underline orange]FPS={1/np.mean(rollout_time):.2f} Hz[reset])"
            ))

    return all_records