src.tasks.test_once module

src.tasks.test_once.get_xy_error(args, pos, veh, mask, pos_true, pos_pred, sample_idx, valid_idx)
src.tasks.test_once.get_collision_rate(args, map_data, future_veh, veh_length, mask, pos_pred)
src.tasks.test_once.get_collision_rate2(args, map_data, future_veh, veh_length, mask, pos_pred, pos_true)
src.tasks.test_once.test_once(args, test_loaders: List[torch.utils.data.DataLoader], model: Model, criterion: torch.nn.Module, diffusion: DDPM, epoch: int) → Dict

Run one evaluation epoch over the provided loaders.

Parameters:
  • args – Global arguments.

  • test_loaders – Test dataloaders.

  • model – Model under evaluation.

  • criterion – Loss function.

  • diffusion – Diffusion module.

  • epoch – Current epoch index.

Returns:

Evaluation metrics dictionary (all_records).

Return type:

Dict
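
The overall shape of an evaluation pass over several loaders that returns a metrics dictionary can be sketched as follows. This is a minimal illustration and not the implementation of test_once: the names evaluate_once, predict, and mean_abs_error, the (inputs, targets) batch layout, and the keys of the returned dictionary are all assumptions made for the example. Plain Python sequences stand in for tensors and dataloaders so the sketch runs without PyTorch.

```python
from typing import Callable, Dict, Iterable, List, Sequence, Tuple

# Hypothetical batch layout: (inputs, targets). The real loaders may differ.
Batch = Tuple[Sequence[float], Sequence[float]]


def evaluate_once(
    loaders: List[Iterable[Batch]],
    predict: Callable[[Sequence[float]], Sequence[float]],
    criterion: Callable[[Sequence[float], Sequence[float]], float],
    epoch: int,
) -> Dict[str, float]:
    """Average `criterion` over every batch of every loader (illustrative only)."""
    total_loss = 0.0
    num_batches = 0
    for loader in loaders:
        for inputs, targets in loader:
            # Score one batch and accumulate the running total.
            total_loss += criterion(predict(inputs), targets)
            num_batches += 1
    # Guard against empty loaders instead of dividing by zero.
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    return {"epoch": epoch, "loss": mean_loss, "num_batches": num_batches}


def mean_abs_error(pred: Sequence[float], true: Sequence[float]) -> float:
    """Stand-in criterion: mean absolute error over one batch."""
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(true)


# Two toy "loaders", each a list of (inputs, targets) batches.
loaders = [
    [([1.0, 2.0], [1.0, 2.0])],
    [([0.0, 0.0], [1.0, 3.0]), ([2.0], [2.0])],
]
records = evaluate_once(loaders, predict=lambda x: x, criterion=mean_abs_error, epoch=0)
print(records)
```

In a PyTorch setting, the same loop would typically iterate over torch.utils.data.DataLoader objects with the model in evaluation mode and gradient tracking disabled; whether test_once does so is not stated here.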