1
1
from __future__ import annotations
2
2
3
+ import warnings
3
4
from decimal import Decimal
4
5
from typing import TYPE_CHECKING
5
6
10
11
VectorMap ,
11
12
)
12
13
from trajdata .maps .map_kdtree import MapElementKDTree
14
+ from trajdata .maps .map_strtree import MapElementSTRTree
13
15
14
16
import pickle
15
17
from math import ceil , floor
@@ -654,7 +656,7 @@ def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bo
654
656
655
657
def get_traffic_light_status_dict (
656
658
self , desired_dt : Optional [float ] = None
657
- ) -> Dict [Tuple [int , int ], TrafficLightStatus ]:
659
+ ) -> Dict [Tuple [str , int ], TrafficLightStatus ]:
658
660
"""
659
661
Returns dict mapping Lane Id, scene_ts to traffic light status for the
660
662
particular scene. If data doesn't exist for the current dt, interpolates and
@@ -704,18 +706,20 @@ def are_maps_cached(cache_path: Path, env_name: str) -> bool:
704
706
@staticmethod
705
707
def get_map_paths (
706
708
cache_path : Path , env_name : str , map_name : str , resolution : float
707
- ) -> Tuple [Path , Path , Path , Path , Path ]:
709
+ ) -> Tuple [Path , Path , Path , Path , Path , Path ]:
708
710
maps_path : Path = DataFrameCache .get_maps_path (cache_path , env_name )
709
711
710
712
vector_map_path : Path = maps_path / f"{ map_name } .pb"
711
713
kdtrees_path : Path = maps_path / f"{ map_name } _kdtrees.dill"
714
+ rtrees_path : Path = maps_path / f"{ map_name } _rtrees.dill"
712
715
raster_map_path : Path = maps_path / f"{ map_name } _{ resolution :.2f} px_m.zarr"
713
716
raster_metadata_path : Path = maps_path / f"{ map_name } _{ resolution :.2f} px_m.dill"
714
717
715
718
return (
716
719
maps_path ,
717
720
vector_map_path ,
718
721
kdtrees_path ,
722
+ rtrees_path ,
719
723
raster_map_path ,
720
724
raster_metadata_path ,
721
725
)
@@ -728,13 +732,19 @@ def is_map_cached(
728
732
maps_path ,
729
733
vector_map_path ,
730
734
kdtrees_path ,
735
+ rtrees_path ,
731
736
raster_map_path ,
732
737
raster_metadata_path ,
733
738
) = DataFrameCache .get_map_paths (cache_path , env_name , map_name , resolution )
739
+
740
+ # TODO(bivanovic): For now, rtrees are optional to have in the cache.
741
+ # In the future, they may be required (likely after we develop an
742
+ # incremental caching scheme or similar to handle additions like these).
734
743
return (
735
744
maps_path .exists ()
736
745
and vector_map_path .exists ()
737
746
and kdtrees_path .exists ()
747
+ # and rtrees_path.exists()
738
748
and raster_metadata_path .exists ()
739
749
and raster_map_path .exists ()
740
750
)
@@ -751,6 +761,7 @@ def finalize_and_cache_map(
751
761
maps_path ,
752
762
vector_map_path ,
753
763
kdtrees_path ,
764
+ rtrees_path ,
754
765
raster_map_path ,
755
766
raster_metadata_path ,
756
767
) = DataFrameCache .get_map_paths (
@@ -775,6 +786,10 @@ def finalize_and_cache_map(
775
786
with open (kdtrees_path , "wb" ) as f :
776
787
dill .dump (vector_map .search_kdtrees , f )
777
788
789
+ # Saving precomputed map element rtrees.
790
+ with open (rtrees_path , "wb" ) as f :
791
+ dill .dump (vector_map .search_rtrees , f )
792
+
778
793
# Saving the rasterized map data.
779
794
zarr .save (raster_map_path , rasterized_map .data )
780
795
@@ -814,7 +829,7 @@ def pad_map_patch(
814
829
return np .pad (patch , [(0 , 0 ), (pad_top , pad_bot ), (pad_left , pad_right )])
815
830
816
831
def load_kdtrees (self ) -> Dict [str , MapElementKDTree ]:
817
- _ , _ , kdtrees_path , _ , _ = DataFrameCache .get_map_paths (
832
+ _ , _ , kdtrees_path , _ , _ , _ = DataFrameCache .get_map_paths (
818
833
self .path , self .scene .env_name , self .scene .location , 0.0
819
834
)
820
835
@@ -840,6 +855,47 @@ def get_kdtrees(self, load_only_once: bool = True):
840
855
else :
841
856
return self ._kdtrees
842
857
858
+ def load_rtrees (self ) -> MapElementSTRTree :
859
+ _ , _ , _ , rtrees_path , _ , _ = DataFrameCache .get_map_paths (
860
+ self .path , self .scene .env_name , self .scene .location , 0.0
861
+ )
862
+
863
+ if not rtrees_path .exists ():
864
+ warnings .warn (
865
+ (
866
+ "Trying to load cached RTree encoding 2D Map elements, "
867
+ f"but { rtrees_path } does not exist. Earlier versions of "
868
+ "trajdata did not build and cache this RTree. If area queries "
869
+ "are needed, please rebuild the map cache (see "
870
+ "examples/preprocess_maps.py for an example of how to do this). "
871
+ "Otherwise, please ignore this warning."
872
+ ),
873
+ UserWarning ,
874
+ )
875
+ return None
876
+
877
+ with open (rtrees_path , "rb" ) as f :
878
+ rtrees : MapElementSTRTree = dill .load (f )
879
+
880
+ return rtrees
881
+
882
+ def get_rtrees (self , load_only_once : bool = True ):
883
+ """Loads and returns the rtrees object from the cache file.
884
+
885
+ Args:
886
+ load_only_once (bool): store the kdtree dictionary in self so that we
887
+ dont have to load it from the cache file more than once.
888
+ """
889
+ if self ._rtrees is None :
890
+ rtrees = self .load_rtrees ()
891
+ if load_only_once :
892
+ self ._rtrees = rtrees
893
+
894
+ return rtrees
895
+
896
+ else :
897
+ return self ._rtrees
898
+
843
899
def load_map_patch (
844
900
self ,
845
901
world_x : float ,
@@ -856,6 +912,7 @@ def load_map_patch(
856
912
maps_path ,
857
913
_ ,
858
914
_ ,
915
+ _ ,
859
916
raster_map_path ,
860
917
raster_metadata_path ,
861
918
) = DataFrameCache .get_map_paths (
0 commit comments