Source code for src.dataset.orca_dataset

import logging
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from argparse import Namespace
from PIL import Image, ImageOps
from .base_dataset import BaseDataset, EmptyDatasetError, RasterizedMap
from ..utils.homography import calc_homography_mat, affine_transformation, image_to_world
from typing import List

# Module-level logger, named after this module per the standard `logging` convention.
_logger = logging.getLogger(name=__name__)

class ORCADataset(BaseDataset):
    """
    ORCA (Optimal Reciprocal Collision Avoidance) simulation dataset loader.

    Used for loading synthetic trajectory data generated by the ORCA algorithm,
    typically for benchmarking or pretraining.
    """

    raw_fps = None  # Depends on the data file: each scene stores its own rate in `fps.txt`.

    @classmethod
    def load_data(cls, args: Namespace, data_path: str, with_shape=False) -> "ORCADataset":
        """
        Load a single ORCA simulation scene.

        Read the frame rate from `fps.txt` and load the trajectory data from csv.
        If `map.png` exists next to the data file, load it (with its world extent
        from `map_range.txt`) and compute the corresponding map in world coordinates.

        Args:
            args (Namespace): Global arguments. Reads `cache_dataset`, `fps` and
                `dot_per_meter`.
            data_path (str): Path to `data.csv` or `data.csv.gz`.
            with_shape (bool): Unused; accepted for interface compatibility.

        Returns:
            ORCADataset: Initialized dataset instance.

        Raises:
            EmptyDatasetError: If a readable cached dataset for this scene is empty.
            FileNotFoundError: If no usable cache exists and `data_path` is missing.
        """
        data_path = Path(data_path)
        name = data_path.parent.name  # The scene is named after its directory.

        ## Check cache.
        cache_path = cls._make_cache_path(args, str(data_path), name)
        if args.cache_dataset and cache_path.exists():
            _logger.info(f"Loading cached dataset from {cache_path}")
            try:
                dataset = cls.load_cache(cache_path)
                if len(dataset) == 0:
                    # If it can be read but is empty, regenerating it would still be
                    # empty, so fail early.
                    raise EmptyDatasetError(f"Cached dataset {cache_path} is empty.")
                cls.collate_fn([dataset[0]])  # Sanity-check that the dataset is usable.
                return dataset
            except EmptyDatasetError:
                # Must be re-raised explicitly: the generic handler below would
                # otherwise swallow it and silently rebuild the (still empty) dataset.
                raise
            except Exception as e:
                # Unreadable or incompatible cache: fall through and rebuild it.
                _logger.error(f"Failed to use cached dataset {cache_path}: {e}")

        if not data_path.exists():
            raise FileNotFoundError(f"Data path {data_path} not found.")

        ## Read raw data. Every ORCA agent is labelled as a pedestrian.
        df_data = pd.read_csv(data_path).assign(type='pedestrian')

        ## Resample data from the scene's own frame rate to `args.fps`.
        raw_fps = float((data_path.parent / 'fps.txt').read_text().strip())
        df_data = cls.resample_dataframe(df_data, raw_fps=raw_fps, target_fps=args.fps)

        ## Load map. It is None when the scene has no `map.png`.
        # NOTE(review): assumes `normalize_xy` and the constructor accept map_data=None;
        # before, this path always failed with UnboundLocalError.
        map_data = cls._load_map(data_path.parent, dot_per_meter=args.dot_per_meter)

        ## Normalize coordinates.
        df_data, map_data = cls.normalize_xy(df_data, map_data)

        ## Build dataset object.
        dataset = cls(name=name, args=args, df_data=df_data, map_data=map_data)

        ## Save cache. It is written regardless of `args.cache_dataset`, which only gates reading.
        _logger.info(f"Caching dataset to {cache_path}")
        cls.save_cache(dataset, cache_path)
        return dataset

    @classmethod
    def _load_map(cls, scene_dir: Path, dot_per_meter, max_pixels: float = 1e5) -> "RasterizedMap | None":
        """
        Load `scene_dir/map.png` as a rasterized map in world coordinates.

        The world extent is read from `scene_dir/map_range.txt` as
        `xmin xmax ymin ymax`. Returns None when the scene has no `map.png`.
        """
        image_path = scene_dir / "map.png"
        if not image_path.exists():
            return None

        with Image.open(image_path) as raw_image:
            image = raw_image.convert('L')
        # Invert black and white so obstacles become white and roads become black.
        image = ImageOps.invert(image)

        # Downscale the image to avoid excessive memory usage inside `image_to_world`.
        # PIL's `Image.size` is (width, height) and `resize` expects the same order.
        # Unpacking it as (h, w) would transpose the aspect ratio.
        width, height = image.size
        total_pixels = width * height
        if total_pixels > max_pixels:
            scale = (max_pixels / total_pixels) ** 0.5
            # Clamp to at least one pixel so very elongated maps do not collapse to size 0.
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            image = image.resize(new_size, Image.LANCZOS)

        pixels = np.array(image) / 255.0  # (H, W), first dimension downward, second dimension rightward.
        height, width = pixels.shape

        xmin0, xmax0, ymin0, ymax0 = np.loadtxt(scene_dir / 'map_range.txt')
        # Pixel corners (row, col) -> world corners: the top-left pixel maps to
        # (xmin, ymax), bottom-left to (xmin, ymin), top-right to (xmax, ymax), and
        # bottom-right to (xmax, ymin).
        H = calc_homography_mat(
            np.array([[0, 0], [height, 0], [0, width], [height, width]]),
            np.array([[xmin0, ymax0], [xmin0, ymin0], [xmax0, ymax0], [xmax0, ymin0]]),
        )
        world_map, xmin, xmax, ymin, ymax = image_to_world(pixels, H, dot_per_meter=dot_per_meter)
        # First dimension rightward and second upward, i.e. xy coordinates.
        return RasterizedMap(map=world_map, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)

    @classmethod
    def load_data_batch(cls, args: Namespace, data_path: str, show_tqdm=True) -> List["ORCADataset"]:
        """
        Batch-load ORCA datasets.

        `data_path` may be a directory (searched recursively for `data.csv.gz`),
        a glob pattern containing `*`, or a single data file. It must lie under
        `./data`. Absolute paths are accepted when they point inside the current
        working directory. The resolved file list is cached under `./data/.cache`.
        Scenes that fail to load or have no samples are logged and skipped.

        Args:
            args (Namespace): Global arguments, forwarded to `load_data`.
            data_path (str): Directory, glob pattern or data file under `./data`.
            show_tqdm (bool): Whether to display a progress bar.

        Returns:
            List[ORCADataset]: The successfully loaded, non-empty datasets.

        Raises:
            ValueError: If `data_path` is not under `./data`.
        """
        data_path = Path(data_path)
        if data_path.is_absolute():
            # The cache key and the glob below are both relative to the working directory.
            data_path = data_path.relative_to(Path.cwd())

        ## Check cache.
        name = '-'.join(data_path.relative_to('./data').parts)
        cache_path = Path('./data/.cache') / Path(name).with_suffix(".pkl")
        files = cls._load_cached_file_list(args, cache_path)
        if files is None:
            files = cls._find_data_files(data_path)
            _logger.info(f"Caching dataset-list to {cache_path}")
            cls.save_cache(files, cache_path)

        datasets = []
        pbar = tqdm(files, disable=not show_tqdm, desc="Loading ORCA datasets")
        for file in pbar:
            try:
                pbar.set_postfix_str(file.parent.parent.name + "/" + file.parent.name)
                dataset = cls.load_data(args, file)
                if len(dataset.samples) == 0:
                    raise ValueError(f"Dataset {file} has no samples, skipping.")
                dataset.path = str(file)
                datasets.append(dataset)
            except Exception as e:
                # Best-effort batch loading: one broken scene must not abort the rest.
                _logger.error(f"Failed to load {file}: {e}")
        return datasets

    @classmethod
    def _load_cached_file_list(cls, args: Namespace, cache_path: Path):
        """
        Return the cached list of data files, or None if the cache cannot be used.

        Uses explicit checks rather than `assert`, which `python -O` strips and
        which would then bypass the `cache_dataset` switch.
        """
        if not args.cache_dataset:
            reason = "Cache disabled"
        elif not cache_path.exists():
            reason = f"Cache {cache_path} not found"
        else:
            _logger.info(f"Loading cached dataset-list from {cache_path}")
            try:
                return cls.load_cache(cache_path)
            except Exception as e:
                reason = e
        _logger.info(f"Failed to load cached dataset-list from {cache_path} since: {reason}")
        return None

    @staticmethod
    def _find_data_files(data_path: Path) -> List[Path]:
        """Resolve a (relative) directory, glob pattern or file into a sorted file list."""
        if data_path.is_dir():
            return sorted(data_path.glob("**/data.csv.gz"))
        if "*" in str(data_path):
            # Pass the pattern as str: `Path.glob` only documents path-like patterns from Python 3.13.
            return sorted(Path(".").glob(str(data_path)))
        return [data_path]