11from __future__ import annotations
22
3+ import warnings
34from decimal import Decimal
45from typing import TYPE_CHECKING
56
1011 VectorMap ,
1112 )
1213 from trajdata .maps .map_kdtree import MapElementKDTree
14+ from trajdata .maps .map_strtree import MapElementSTRTree
1315
1416import pickle
1517from math import ceil , floor
@@ -654,7 +656,7 @@ def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bo
654656
655657 def get_traffic_light_status_dict (
656658 self , desired_dt : Optional [float ] = None
657- ) -> Dict [Tuple [int , int ], TrafficLightStatus ]:
659+ ) -> Dict [Tuple [str , int ], TrafficLightStatus ]:
658660 """
659661 Returns dict mapping Lane Id, scene_ts to traffic light status for the
660662 particular scene. If data doesn't exist for the current dt, interpolates and
@@ -704,18 +706,20 @@ def are_maps_cached(cache_path: Path, env_name: str) -> bool:
704706 @staticmethod
705707 def get_map_paths (
706708 cache_path : Path , env_name : str , map_name : str , resolution : float
707- ) -> Tuple [Path , Path , Path , Path , Path ]:
709+ ) -> Tuple [Path , Path , Path , Path , Path , Path ]:
708710 maps_path : Path = DataFrameCache .get_maps_path (cache_path , env_name )
709711
710712 vector_map_path : Path = maps_path / f"{ map_name } .pb"
711713 kdtrees_path : Path = maps_path / f"{ map_name } _kdtrees.dill"
714+ rtrees_path : Path = maps_path / f"{ map_name } _rtrees.dill"
712715 raster_map_path : Path = maps_path / f"{ map_name } _{ resolution :.2f} px_m.zarr"
713716 raster_metadata_path : Path = maps_path / f"{ map_name } _{ resolution :.2f} px_m.dill"
714717
715718 return (
716719 maps_path ,
717720 vector_map_path ,
718721 kdtrees_path ,
722+ rtrees_path ,
719723 raster_map_path ,
720724 raster_metadata_path ,
721725 )
@@ -728,13 +732,19 @@ def is_map_cached(
728732 maps_path ,
729733 vector_map_path ,
730734 kdtrees_path ,
735+ rtrees_path ,
731736 raster_map_path ,
732737 raster_metadata_path ,
733738 ) = DataFrameCache .get_map_paths (cache_path , env_name , map_name , resolution )
739+
740+ # TODO(bivanovic): For now, rtrees are optional to have in the cache.
741+ # In the future, they may be required (likely after we develop an
742+ # incremental caching scheme or similar to handle additions like these).
734743 return (
735744 maps_path .exists ()
736745 and vector_map_path .exists ()
737746 and kdtrees_path .exists ()
747+ # and rtrees_path.exists()
738748 and raster_metadata_path .exists ()
739749 and raster_map_path .exists ()
740750 )
@@ -751,6 +761,7 @@ def finalize_and_cache_map(
751761 maps_path ,
752762 vector_map_path ,
753763 kdtrees_path ,
764+ rtrees_path ,
754765 raster_map_path ,
755766 raster_metadata_path ,
756767 ) = DataFrameCache .get_map_paths (
@@ -775,6 +786,10 @@ def finalize_and_cache_map(
775786 with open (kdtrees_path , "wb" ) as f :
776787 dill .dump (vector_map .search_kdtrees , f )
777788
789+ # Saving precomputed map element rtrees.
790+ with open (rtrees_path , "wb" ) as f :
791+ dill .dump (vector_map .search_rtrees , f )
792+
778793 # Saving the rasterized map data.
779794 zarr .save (raster_map_path , rasterized_map .data )
780795
@@ -814,7 +829,7 @@ def pad_map_patch(
814829 return np .pad (patch , [(0 , 0 ), (pad_top , pad_bot ), (pad_left , pad_right )])
815830
816831 def load_kdtrees (self ) -> Dict [str , MapElementKDTree ]:
817- _ , _ , kdtrees_path , _ , _ = DataFrameCache .get_map_paths (
832+ _ , _ , kdtrees_path , _ , _ , _ = DataFrameCache .get_map_paths (
818833 self .path , self .scene .env_name , self .scene .location , 0.0
819834 )
820835
@@ -840,6 +855,47 @@ def get_kdtrees(self, load_only_once: bool = True):
840855 else :
841856 return self ._kdtrees
842857
858+ def load_rtrees (self ) -> MapElementSTRTree :
859+ _ , _ , _ , rtrees_path , _ , _ = DataFrameCache .get_map_paths (
860+ self .path , self .scene .env_name , self .scene .location , 0.0
861+ )
862+
863+ if not rtrees_path .exists ():
864+ warnings .warn (
865+ (
866+ "Trying to load cached RTree encoding 2D Map elements, "
867+ f"but { rtrees_path } does not exist. Earlier versions of "
868+ "trajdata did not build and cache this RTree. If area queries "
869+ "are needed, please rebuild the map cache (see "
870+ "examples/preprocess_maps.py for an example of how to do this). "
871+ "Otherwise, please ignore this warning."
872+ ),
873+ UserWarning ,
874+ )
875+ return None
876+
877+ with open (rtrees_path , "rb" ) as f :
878+ rtrees : MapElementSTRTree = dill .load (f )
879+
880+ return rtrees
881+
882+ def get_rtrees (self , load_only_once : bool = True ):
883+ """Loads and returns the rtrees object from the cache file.
884+
885+ Args:
886+ load_only_once (bool): store the kdtree dictionary in self so that we
887+ dont have to load it from the cache file more than once.
888+ """
889+ if self ._rtrees is None :
890+ rtrees = self .load_rtrees ()
891+ if load_only_once :
892+ self ._rtrees = rtrees
893+
894+ return rtrees
895+
896+ else :
897+ return self ._rtrees
898+
843899 def load_map_patch (
844900 self ,
845901 world_x : float ,
@@ -856,6 +912,7 @@ def load_map_patch(
856912 maps_path ,
857913 _ ,
858914 _ ,
915+ _ ,
859916 raster_map_path ,
860917 raster_metadata_path ,
861918 ) = DataFrameCache .get_map_paths (
0 commit comments