src.tasks.test_once module
- src.tasks.test_once.get_xy_error(args, pos, veh, mask, pos_true, pos_pred, sample_idx, valid_idx)
- src.tasks.test_once.get_collision_rate(args, map_data, future_veh, veh_length, mask, pos_pred)
- src.tasks.test_once.get_collision_rate2(args, map_data, future_veh, veh_length, mask, pos_pred, pos_true)
- src.tasks.test_once.test_once(args, test_loaders: List[torch.utils.data.DataLoader], model: Model, criterion: torch.nn.Module, diffusion: DDPM, epoch: int) → Dict
Run one evaluation epoch over the provided loaders.
- Parameters:
args – Global arguments.
test_loaders – Test dataloaders.
model – Model under evaluation.
criterion – Loss function.
diffusion – Diffusion module.
epoch – Current epoch index.
- Returns:
all_records – Evaluation metrics dictionary.
- Return type:
Dict
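The page does not document how `test_once` or `get_xy_error` compute their metrics. The following is only a minimal, stdlib-only sketch of a typical evaluation epoch: iterate over every loader, compare predicted and true XY positions where a validity mask is set, and collect per-loader results into a dictionary. Everything in it is a hypothetical stand-in, not this module's implementation. That includes `masked_xy_error`, `evaluate_epoch`, the `(inputs, pos_true, mask)` batch layout, the `predict` callable (standing in for the model plus diffusion sampling), and the `loader{i}/ade` key names. It also ignores `criterion`, the sampling indices, and the collision metrics.

```python
import math
from typing import Callable, Dict, Iterable, List, Sequence, Tuple

Point = Tuple[float, float]
# Hypothetical batch layout: (model inputs, true positions [agent][t], mask [agent][t]).
Batch = Tuple[object, Sequence[Sequence[Point]], Sequence[Sequence[bool]]]


def masked_xy_error(pos_true: Sequence[Sequence[Point]],
                    pos_pred: Sequence[Sequence[Point]],
                    mask: Sequence[Sequence[bool]]) -> Tuple[float, int]:
    """Return (sum of Euclidean XY errors, number of valid entries).

    Only (agent, timestep) entries whose mask value is True are counted.
    """
    total, count = 0.0, 0
    for agent_true, agent_pred, agent_mask in zip(pos_true, pos_pred, mask):
        for (xt, yt), (xp, yp), valid in zip(agent_true, agent_pred, agent_mask):
            if valid:
                total += math.hypot(xp - xt, yp - yt)
                count += 1
    return total, count


def evaluate_epoch(loaders: List[Iterable[Batch]],
                   predict: Callable[[object], Sequence[Sequence[Point]]],
                   epoch: int) -> Dict[str, float]:
    """Hypothetical evaluation epoch: mean masked XY error per loader."""
    records: Dict[str, float] = {"epoch": epoch}
    for i, loader in enumerate(loaders):
        err_sum, n_valid = 0.0, 0
        for inputs, pos_true, mask in loader:
            pos_pred = predict(inputs)  # stand-in for model + diffusion sampling
            batch_sum, batch_count = masked_xy_error(pos_true, pos_pred, mask)
            err_sum += batch_sum
            n_valid += batch_count
        # Average over all valid entries, not over batch means; NaN if none were valid.
        records[f"loader{i}/ade"] = err_sum / n_valid if n_valid else float("nan")
    return records


def predict_origin(inputs: object) -> List[List[Point]]:
    """Toy predictor: one agent that stays at the origin for two timesteps."""
    return [[(0.0, 0.0), (0.0, 0.0)]]


toy_loader = [((), [[(0.0, 0.0), (3.0, 4.0)]], [[True, True]])]
print(evaluate_epoch([toy_loader], predict_origin, epoch=0))
# {'epoch': 0, 'loader0/ade': 2.5}
```

Accumulating the error sum and valid-entry count across batches, rather than averaging per-batch means, keeps batches with few valid agents from being over-weighted. In the toy run, the errors are 0 and 5 over two valid entries, giving 2.5.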