Skip to content

Commit c1a9499

Browse files
committed
Adding quality of life improvements, map area querying, fixing nuPlan lane ID format, and the ability to cache the data index.
1 parent 748b8b1 commit c1a9499

27 files changed

+801
-159
lines changed

examples/preprocess_maps.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# @profile
55
def main():
66
dataset = UnifiedDataset(
7+
# TODO([email protected]) Remove lyft from default examples
78
desired_data=["nusc_mini", "lyft_sample", "nuplan_mini"],
89
rebuild_maps=True,
910
data_dirs={ # Remember to change this to match your filesystem!

examples/simple_map_api_example.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77

88
from trajdata import MapAPI, VectorMap
9+
from trajdata.maps.vec_map_elements import MapElementType
10+
from trajdata.utils import map_utils
911

1012

1113
def main():
@@ -23,12 +25,16 @@ def main():
2325
}
2426

2527
start = time.perf_counter()
26-
vec_map: VectorMap = map_api.get_map(f"{env_name}:{random_location_dict[env_name]}")
28+
vec_map: VectorMap = map_api.get_map(
29+
f"{env_name}:{random_location_dict[env_name]}", incl_road_areas=True
30+
)
2731
end = time.perf_counter()
2832
print(f"Map loading took {(end - start)*1000:.2f} ms")
2933

3034
start = time.perf_counter()
31-
vec_map: VectorMap = map_api.get_map(f"{env_name}:{random_location_dict[env_name]}")
35+
vec_map: VectorMap = map_api.get_map(
36+
f"{env_name}:{random_location_dict[env_name]}", incl_road_areas=True
37+
)
3238
end = time.perf_counter()
3339
print(f"Repeated (cached in memory) map loading took {(end - start)*1000:.2f} ms")
3440

@@ -64,6 +70,35 @@ def main():
6470
end = time.perf_counter()
6571
print(f"Lane visualization took {(end - start)*1000:.2f} ms")
6672

73+
point = vec_map.lanes[lane_idx].center.xyz[0, :]
74+
point_raster = map_utils.transform_points(
75+
point[None, :], transf_mat=raster_from_world
76+
)
77+
ax.scatter(point_raster[:, 0], point_raster[:, 1])
78+
79+
print("Getting nearest road area...")
80+
start = time.perf_counter()
81+
area = vec_map.get_closest_area(point, elem_type=MapElementType.ROAD_AREA)
82+
end = time.perf_counter()
83+
print(f"Getting nearest area took {(end-start)*1000:.2f} ms")
84+
85+
raster_pts = map_utils.transform_points(area.exterior_polygon.xy, raster_from_world)
86+
ax.fill(raster_pts[:, 0], raster_pts[:, 1], alpha=1.0, color="C0")
87+
88+
print("Getting road areas within 100m...")
89+
start = time.perf_counter()
90+
areas = vec_map.get_areas_within(
91+
point, elem_type=MapElementType.ROAD_AREA, dist=100.0
92+
)
93+
end = time.perf_counter()
94+
print(f"Getting areas within took {(end-start)*1000:.2f} ms")
95+
96+
for area in areas:
97+
raster_pts = map_utils.transform_points(
98+
area.exterior_polygon.xy, raster_from_world
99+
)
100+
ax.fill(raster_pts[:, 0], raster_pts[:, 1], alpha=0.2, color="C1")
101+
67102
ax.axis("equal")
68103
ax.grid(None)
69104

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ classifiers = [
1313
"License :: OSI Approved :: Apache Software License",
1414
]
1515
name = "trajdata"
16-
version = "1.3.3"
16+
version = "1.4.0"
1717
authors = [{ name = "Boris Ivanovic", email = "[email protected]" }]
1818
description = "A unified interface to many trajectory forecasting datasets."
1919
readme = "README.md"
@@ -33,7 +33,8 @@ dependencies = [
3333
"geopandas>=0.13.2",
3434
"protobuf==3.19.4",
3535
"scipy>=1.9.0",
36-
"opencv-python>=4.5.0"
36+
"opencv-python>=4.5.0",
37+
"shapely>=2.0.0",
3738
]
3839

3940
[project.optional-dependencies]

src/trajdata/caching/df_cache.py

+60-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from decimal import Decimal
45
from typing import TYPE_CHECKING
56

@@ -10,6 +11,7 @@
1011
VectorMap,
1112
)
1213
from trajdata.maps.map_kdtree import MapElementKDTree
14+
from trajdata.maps.map_strtree import MapElementSTRTree
1315

1416
import pickle
1517
from math import ceil, floor
@@ -654,7 +656,7 @@ def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bo
654656

655657
def get_traffic_light_status_dict(
656658
self, desired_dt: Optional[float] = None
657-
) -> Dict[Tuple[int, int], TrafficLightStatus]:
659+
) -> Dict[Tuple[str, int], TrafficLightStatus]:
658660
"""
659661
Returns dict mapping Lane Id, scene_ts to traffic light status for the
660662
particular scene. If data doesn't exist for the current dt, interpolates and
@@ -704,18 +706,20 @@ def are_maps_cached(cache_path: Path, env_name: str) -> bool:
704706
@staticmethod
705707
def get_map_paths(
706708
cache_path: Path, env_name: str, map_name: str, resolution: float
707-
) -> Tuple[Path, Path, Path, Path, Path]:
709+
) -> Tuple[Path, Path, Path, Path, Path, Path]:
708710
maps_path: Path = DataFrameCache.get_maps_path(cache_path, env_name)
709711

710712
vector_map_path: Path = maps_path / f"{map_name}.pb"
711713
kdtrees_path: Path = maps_path / f"{map_name}_kdtrees.dill"
714+
rtrees_path: Path = maps_path / f"{map_name}_rtrees.dill"
712715
raster_map_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.zarr"
713716
raster_metadata_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.dill"
714717

715718
return (
716719
maps_path,
717720
vector_map_path,
718721
kdtrees_path,
722+
rtrees_path,
719723
raster_map_path,
720724
raster_metadata_path,
721725
)
@@ -728,13 +732,19 @@ def is_map_cached(
728732
maps_path,
729733
vector_map_path,
730734
kdtrees_path,
735+
rtrees_path,
731736
raster_map_path,
732737
raster_metadata_path,
733738
) = DataFrameCache.get_map_paths(cache_path, env_name, map_name, resolution)
739+
740+
# TODO(bivanovic): For now, rtrees are optional to have in the cache.
741+
# In the future, they may be required (likely after we develop an
742+
# incremental caching scheme or similar to handle additions like these).
734743
return (
735744
maps_path.exists()
736745
and vector_map_path.exists()
737746
and kdtrees_path.exists()
747+
# and rtrees_path.exists()
738748
and raster_metadata_path.exists()
739749
and raster_map_path.exists()
740750
)
@@ -751,6 +761,7 @@ def finalize_and_cache_map(
751761
maps_path,
752762
vector_map_path,
753763
kdtrees_path,
764+
rtrees_path,
754765
raster_map_path,
755766
raster_metadata_path,
756767
) = DataFrameCache.get_map_paths(
@@ -775,6 +786,10 @@ def finalize_and_cache_map(
775786
with open(kdtrees_path, "wb") as f:
776787
dill.dump(vector_map.search_kdtrees, f)
777788

789+
# Saving precomputed map element rtrees.
790+
with open(rtrees_path, "wb") as f:
791+
dill.dump(vector_map.search_rtrees, f)
792+
778793
# Saving the rasterized map data.
779794
zarr.save(raster_map_path, rasterized_map.data)
780795

@@ -814,7 +829,7 @@ def pad_map_patch(
814829
return np.pad(patch, [(0, 0), (pad_top, pad_bot), (pad_left, pad_right)])
815830

816831
def load_kdtrees(self) -> Dict[str, MapElementKDTree]:
817-
_, _, kdtrees_path, _, _ = DataFrameCache.get_map_paths(
832+
_, _, kdtrees_path, _, _, _ = DataFrameCache.get_map_paths(
818833
self.path, self.scene.env_name, self.scene.location, 0.0
819834
)
820835

@@ -840,6 +855,47 @@ def get_kdtrees(self, load_only_once: bool = True):
840855
else:
841856
return self._kdtrees
842857

858+
def load_rtrees(self) -> MapElementSTRTree:
859+
_, _, _, rtrees_path, _, _ = DataFrameCache.get_map_paths(
860+
self.path, self.scene.env_name, self.scene.location, 0.0
861+
)
862+
863+
if not rtrees_path.exists():
864+
warnings.warn(
865+
(
866+
"Trying to load cached RTree encoding 2D Map elements, "
867+
f"but {rtrees_path} does not exist. Earlier versions of "
868+
"trajdata did not build and cache this RTree. If area queries "
869+
"are needed, please rebuild the map cache (see "
870+
"examples/preprocess_maps.py for an example of how to do this). "
871+
"Otherwise, please ignore this warning."
872+
),
873+
UserWarning,
874+
)
875+
return None
876+
877+
with open(rtrees_path, "rb") as f:
878+
rtrees: MapElementSTRTree = dill.load(f)
879+
880+
return rtrees
881+
882+
def get_rtrees(self, load_only_once: bool = True):
883+
"""Loads and returns the rtrees object from the cache file.
884+
885+
Args:
886+
load_only_once (bool): store the kdtree dictionary in self so that we
887+
dont have to load it from the cache file more than once.
888+
"""
889+
if self._rtrees is None:
890+
rtrees = self.load_rtrees()
891+
if load_only_once:
892+
self._rtrees = rtrees
893+
894+
return rtrees
895+
896+
else:
897+
return self._rtrees
898+
843899
def load_map_patch(
844900
self,
845901
world_x: float,
@@ -856,6 +912,7 @@ def load_map_patch(
856912
maps_path,
857913
_,
858914
_,
915+
_,
859916
raster_map_path,
860917
raster_metadata_path,
861918
) = DataFrameCache.get_map_paths(

src/trajdata/caching/scene_cache.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bo
168168

169169
def get_traffic_light_status_dict(
170170
self, desired_dt: Optional[float] = None
171-
) -> Dict[Tuple[int, int], TrafficLightStatus]:
171+
) -> Dict[Tuple[str, int], TrafficLightStatus]:
172172
"""Returns lookup table for traffic light status in the current scene
173173
lane_id, scene_ts -> TrafficLightStatus"""
174174
raise NotImplementedError()

src/trajdata/data_structures/batch_element.py

-3
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,6 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray:
441441
**vector_map_params if vector_map_params is not None else None,
442442
)
443443

444-
self.scene_id = scene_time.scene.name
445-
446444
### ROBOT DATA ###
447445
self.robot_future_np: Optional[StateArray] = None
448446

@@ -506,7 +504,6 @@ def get_agents_future(
506504
future_sec: Tuple[Optional[float], Optional[float]],
507505
nearby_agents: List[AgentMetadata],
508506
) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]:
509-
510507
(
511508
agent_futures,
512509
agent_future_extents,

src/trajdata/data_structures/collation.py

-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ def raster_map_collate_fn_scene(
182182
max_agent_num: Optional[int] = None,
183183
pad_value: Any = np.nan,
184184
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
185-
186185
if batch_elems[0].map_patches is None:
187186
return None, None, None, None
188187

src/trajdata/data_structures/scene_tag.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Set, Tuple
23

34

@@ -8,6 +9,9 @@ def __init__(self, tag_tuple: Tuple[str, ...]) -> None:
89
def contains(self, query: Set[str]) -> bool:
910
return query.issubset(self._tag_tuple)
1011

12+
def matches_any(self, regex: re.Pattern) -> bool:
13+
return any(regex.search(x) is not None for x in self._tag_tuple)
14+
1115
def __contains__(self, item) -> bool:
1216
return item in self._tag_tuple
1317

0 commit comments

Comments
 (0)