1010import json
1111import logging
1212import os
13+
1314import urllib
1415from dataclasses import dataclass , Field , field
1516from typing import (
3738 FrameDataBuilder , # noqa
3839 FrameDataBuilderBase ,
3940)
41+
4042from pytorch3d .implicitron .tools .config import (
4143 registry ,
4244 ReplaceableBase ,
4345 run_auto_creation ,
4446)
45- from sqlalchemy .orm import Session
47+ from sqlalchemy .orm import scoped_session , Session , sessionmaker
4648
4749from .orm_types import SqlFrameAnnotation , SqlSequenceAnnotation
4850
@@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
9193 engine verbatim. Don't expose it to end users of your application!
9294 pick_categories: Restrict the dataset to the given list of categories.
9395 pick_sequences: A Sequence of sequence names to restrict the dataset to.
96+ pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
9497 exclude_sequences: A Sequence of the names of the sequences to exclude.
9598 limit_sequences_per_category_to: Limit the dataset to the first up to N
9699 sequences within each category (applies after all other sequence filters
@@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
105108 more frames than that; applied after other frame-level filters.
106109 seed: The seed of the random generator sampling `n_frames_per_sequence`
107110 random frames per sequence.
111+ preload_metadata: If True, the metadata is preloaded into memory.
112+ precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
113+ scoped_session: If True, allows different parts of the code to share
114+ a global session to access the database.
108115 """
109116
110117 frame_annotations_type : ClassVar [Type [SqlFrameAnnotation ]] = SqlFrameAnnotation
@@ -123,13 +130,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
123130 pick_categories : Tuple [str , ...] = ()
124131
125132 pick_sequences : Tuple [str , ...] = ()
133+ pick_sequences_sql_clause : Optional [str ] = None
126134 exclude_sequences : Tuple [str , ...] = ()
127135 limit_sequences_per_category_to : int = 0
128136 limit_sequences_to : int = 0
129137 limit_to : int = 0
130138 n_frames_per_sequence : int = - 1
131139 seed : int = 0
132140 remove_empty_masks_poll_whole_table_threshold : int = 300_000
141+ preload_metadata : bool = False
142+ precompute_seq_to_idx : bool = False
133143 # we set it manually in the constructor
134144 _index : pd .DataFrame = field (init = False , metadata = {"omegaconf_ignore" : True })
135145 _sql_engine : sa .engine .Engine = field (
@@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
142152 frame_data_builder : FrameDataBuilderBase # pyre-ignore[13]
143153 frame_data_builder_class_type : str = "FrameDataBuilder"
144154
155+ scoped_session : bool = False
156+
145157 def __post_init__ (self ) -> None :
146158 if sa .__version__ < "2.0" :
147159 raise ImportError ("This class requires SQL Alchemy 2.0 or later" )
@@ -169,6 +181,9 @@ def __post_init__(self) -> None:
169181 f"sqlite:///file:{ urllib .parse .quote (self .sqlite_metadata_file )} ?mode=ro&uri=true"
170182 )
171183
184+ if self .preload_metadata :
185+ self ._sql_engine = self ._preload_database (self ._sql_engine )
186+
172187 sequences = self ._get_filtered_sequences_if_any ()
173188
174189 if self .subsets :
@@ -192,6 +207,20 @@ def __post_init__(self) -> None:
192207
193208 logger .info (str (self ))
194209
210+ if self .scoped_session :
211+ self ._session_factory = sessionmaker (bind = self ._sql_engine ) # pyre-ignore
212+
213+ if self .precompute_seq_to_idx :
214+ # This is deprecated and will be removed in the future.
215+ # After we backport https://github.com/facebookresearch/uco3d/pull/3
216+ logger .warning (
217+ "Using precompute_seq_to_idx is deprecated and will be removed in the future."
218+ )
219+ self ._index ["rowid" ] = np .arange (len (self ._index ))
220+ groupby = self ._index .groupby ("sequence_name" , sort = False )["rowid" ]
221+ self ._seq_to_indices = dict (groupby .apply (list )) # pyre-ignore
222+ del self ._index ["rowid" ]
223+
195224 def __len__ (self ) -> int :
196225 return len (self ._index )
197226
@@ -252,9 +281,15 @@ def _get_item(
252281 seq_stmt = sa .select (self .sequence_annotations_type ).where (
253282 self .sequence_annotations_type .sequence_name == seq
254283 )
255- with Session (self ._sql_engine ) as session :
256- entry = session .scalars (stmt ).one ()
257- seq_metadata = session .scalars (seq_stmt ).one ()
284+ if self .scoped_session :
285+ # pyre-ignore
286+ with scoped_session (self ._session_factory )() as session :
287+ entry = session .scalars (stmt ).one ()
288+ seq_metadata = session .scalars (seq_stmt ).one ()
289+ else :
290+ with Session (self ._sql_engine ) as session :
291+ entry = session .scalars (stmt ).one ()
292+ seq_metadata = session .scalars (seq_stmt ).one ()
258293
259294 assert entry .image .path == self ._index .loc [(seq , frame ), "_image_path" ]
260295
@@ -363,6 +398,20 @@ def sequence_frames_in_order(
363398
364399 yield from index_slice .itertuples (index = False )
365400
401+ # override
def sequence_indices_in_order(
    self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[int]:
    """Yield the dataset indices of the frames of the sequence `seq_name`.

    Index-only counterpart of `sequence_frames_in_order`.

    Args:
        seq_name: Name of the sequence whose frame indices are requested.
        subset_filter: Optional filter forwarded unchanged to
            `sequence_frames_in_order`; any value other than None bypasses
            the precomputed lookup.

    Yields:
        One dataset index per frame.
    """
    if not self.precompute_seq_to_idx or subset_filter is not None:
        # General path: reuse the frame iterator and keep only the last item
        # of each (_, _, index) triple it produces.
        frames = self.sequence_frames_in_order(seq_name, subset_filter)
        yield from (index for _, _, index in frames)
        return

    # Shortcut: read the per-sequence index list stored in _seq_to_indices.
    # pyre-ignore
    yield from self._seq_to_indices[seq_name]
414+
366415 # override
367416 def get_eval_batches (self ) -> Optional [List [Any ]]:
368417 """
@@ -396,11 +445,35 @@ def is_filtered(self) -> bool:
396445 or self .limit_sequences_to > 0
397446 or self .limit_sequences_per_category_to > 0
398447 or len (self .pick_sequences ) > 0
448+ or self .pick_sequences_sql_clause is not None
399449 or len (self .exclude_sequences ) > 0
400450 or len (self .pick_categories ) > 0
401451 or self .n_frames_per_sequence > 0
402452 )
403453
def _preload_database(
    self, source_engine: sa.engine.base.Engine
) -> sa.engine.base.Engine:
    """Copy all tables visible through `source_engine` into a fresh
    in-memory SQLite database.

    The schema is reflected from the source, recreated in the destination,
    and then every table's rows are copied over.

    Args:
        source_engine: Engine connected to the database to be copied.

    Returns:
        A new engine bound to the in-memory copy.
    """
    # Rows are copied in fixed-size batches so that a large table is neither
    # fully materialised in Python nor inserted one statement per row.
    batch_size = 10_000

    # NOTE(review): SQLAlchemy's SQLite dialect documents special pooling and
    # threading rules for ":memory:" databases; confirm the default pool
    # suits every thread that is expected to read from this engine.
    destination_engine = sa.create_engine("sqlite:///:memory:")
    metadata = sa.MetaData()
    metadata.reflect(bind=source_engine)
    metadata.create_all(bind=destination_engine)

    with source_engine.connect() as source_conn:
        with destination_engine.connect() as destination_conn:
            for table_obj in metadata.tables.values():
                # Built once per table and reused for every batch.
                insert_stmt = table_obj.insert()

                # Stream the source rows as column-name -> value mappings.
                source_rows = source_conn.execute(table_obj.select()).mappings()

                # A list of dicts makes execute() take the executemany
                # path; partitions() never yields an empty batch.
                for batch in source_rows.partitions(batch_size):
                    destination_conn.execute(
                        insert_stmt, [dict(row) for row in batch]
                    )

                # Commit the changes for each table
                destination_conn.commit()

    return destination_engine
476+
404477 def _get_filtered_sequences_if_any (self ) -> Optional [pd .Series ]:
405478 # maximum possible filter (if limit_sequences_per_category_to == 0):
406479 # WHERE category IN 'self.pick_categories'
@@ -413,6 +486,9 @@ def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
413486 * self ._get_pick_filters (),
414487 * self ._get_exclude_filters (),
415488 ]
489+ if self .pick_sequences_sql_clause :
490+ print ("Applying the custom SQL clause." )
491+ where_conditions .append (sa .text (self .pick_sequences_sql_clause ))
416492
417493 def add_where (stmt ):
418494 return stmt .where (* where_conditions ) if where_conditions else stmt
@@ -749,9 +825,15 @@ def _get_frame_no_coalesced_ts_by_row_indices(
749825 self .frame_annotations_type .sequence_name == seq_name ,
750826 self .frame_annotations_type .frame_number .in_ (frames ),
751827 )
828+ frame_no_ts = None
752829
753- with self ._sql_engine .connect () as connection :
754- frame_no_ts = pd .read_sql_query (stmt , connection )
830+ if self .scoped_session :
831+ stmt_text = str (stmt .compile (compile_kwargs = {"literal_binds" : True }))
832+ with scoped_session (self ._session_factory )() as session : # pyre-ignore
833+ frame_no_ts = pd .read_sql_query (stmt_text , session .connection ())
834+ else :
835+ with self ._sql_engine .connect () as connection :
836+ frame_no_ts = pd .read_sql_query (stmt , connection )
755837
756838 if len (frame_no_ts ) != len (index_slice ):
757839 raise ValueError (
0 commit comments