@@ -936,45 +936,56 @@ def _split_by_compound_indices(self, compound, stable_sort=False):
936936 n_compounds : int
937937 The number of individual compounds.
938938 """
939- # Caching would help here, especially when repeating the operation
940- # over different frames, since these masks are coordinate-independent.
941- # However, cache must be invalidated whenever new compound indices are
942- # modified, which is not yet implemented.
943- # Also, should we include here the grouping for 'group', which is
939+ # Should we include here the grouping for 'group', which is
944940 # essentially a non-split?
945941
942+ cache_key = f"{ compound } _masks"
946943 compound_indices = self ._get_compound_indices (compound )
947- compound_sizes = np .bincount (compound_indices )
948- size_per_atom = compound_sizes [compound_indices ]
949- compound_sizes = compound_sizes [compound_sizes != 0 ]
950- unique_compound_sizes = unique_int_1d (compound_sizes )
951-
952- # Are we already sorted? argsorting and fancy-indexing can be expensive
953- # so we do a quick pre-check.
954- needs_sorting = np .any (np .diff (compound_indices ) < 0 )
955- if needs_sorting :
956- # stable sort ensures reproducibility, especially concerning who
957- # gets to be a compound's atom[0] and be a reference for unwrap.
958- if stable_sort :
959- sort_indices = np .argsort (compound_indices , kind = 'stable' )
960- else :
961- # Quicksort
962- sort_indices = np .argsort (compound_indices )
963- # We must sort size_per_atom accordingly (Issue #3352).
964- size_per_atom = size_per_atom [sort_indices ]
965-
966- compound_masks = []
967- atom_masks = []
968- for compound_size in unique_compound_sizes :
969- compound_masks .append (compound_sizes == compound_size )
944+
945+ # create new cache or invalidate cache when compound indices changed
946+ if (
947+ cache_key not in self ._cache
948+ or np .all (self ._cache [cache_key ]["compound_indices" ]
949+ != compound_indices )):
950+ compound_sizes = np .bincount (compound_indices )
951+ size_per_atom = compound_sizes [compound_indices ]
952+ compound_sizes = compound_sizes [compound_sizes != 0 ]
953+ unique_compound_sizes = unique_int_1d (compound_sizes )
954+
955+ # Are we already sorted? argsorting and fancy-indexing can be
956+ # expensive so we do a quick pre-check.
957+ needs_sorting = np .any (np .diff (compound_indices ) < 0 )
970958 if needs_sorting :
971- atom_masks .append (sort_indices [size_per_atom == compound_size ]
972- .reshape (- 1 , compound_size ))
973- else :
974- atom_masks .append (np .where (size_per_atom == compound_size )[0 ]
975- .reshape (- 1 , compound_size ))
959+ # stable sort ensures reproducibility, especially concerning
960+ # who gets to be a compound's atom[0] and be a reference for
961+ # unwrap.
962+ if stable_sort :
963+ sort_indices = np .argsort (compound_indices , kind = 'stable' )
964+ else :
965+ # Quicksort
966+ sort_indices = np .argsort (compound_indices )
967+ # We must sort size_per_atom accordingly (Issue #3352).
968+ size_per_atom = size_per_atom [sort_indices ]
969+
970+ compound_masks = []
971+ atom_masks = []
972+ for compound_size in unique_compound_sizes :
973+ compound_masks .append (compound_sizes == compound_size )
974+ if needs_sorting :
975+ atom_masks .append (sort_indices [size_per_atom
976+ == compound_size ]
977+ .reshape (- 1 , compound_size ))
978+ else :
979+ atom_masks .append (np .where (size_per_atom
980+ == compound_size )[0 ]
981+ .reshape (- 1 , compound_size ))
982+
983+ self ._cache [cache_key ] = {
984+ "compound_indices" : compound_indices ,
985+ "data" : (atom_masks , compound_masks , len (compound_sizes ))
986+ }
976987
977- return atom_masks , compound_masks , len ( compound_sizes )
988+ return self . _cache [ cache_key ][ "data" ]
978989
979990 @warn_if_not_unique
980991 @_pbc_to_wrap
@@ -3200,7 +3211,7 @@ def select_atoms(self, sel, *othersel, periodic=True, rtol=1e-05,
32003211 universe = mda.Universe(PSF, DCD)
32013212 guessed_elements = guess_types(universe.atoms.names)
32023213 universe.add_TopologyAttr('elements', guessed_elements)
3203-
3214+
32043215 .. doctest:: AtomGroup.select_atoms.smarts
32053216
32063217 >>> universe.select_atoms("smarts C", smarts_kwargs={"maxMatches": 100})
0 commit comments