@@ -936,45 +936,56 @@ def _split_by_compound_indices(self, compound, stable_sort=False):
936936        n_compounds : int 
937937            The number of individual compounds. 
938938        """ 
939-         # Caching would help here, especially when repeating the operation 
940-         # over different frames, since these masks are coordinate-independent. 
941-         # However, cache must be invalidated whenever new compound indices are 
942-         # modified, which is not yet implemented. 
943-         # Also, should we include here the grouping for 'group', which is 
939+         # Should we include here the grouping for 'group', which is 
944940        # essentially a non-split? 
945941
942+         cache_key  =  f"{ compound }  _masks" 
946943        compound_indices  =  self ._get_compound_indices (compound )
947-         compound_sizes  =  np .bincount (compound_indices )
948-         size_per_atom  =  compound_sizes [compound_indices ]
949-         compound_sizes  =  compound_sizes [compound_sizes  !=  0 ]
950-         unique_compound_sizes  =  unique_int_1d (compound_sizes )
951- 
952-         # Are we already sorted? argsorting and fancy-indexing can be expensive 
953-         # so we do a quick pre-check. 
954-         needs_sorting  =  np .any (np .diff (compound_indices ) <  0 )
955-         if  needs_sorting :
956-             # stable sort ensures reproducibility, especially concerning who 
957-             # gets to be a compound's atom[0] and be a reference for unwrap. 
958-             if  stable_sort :
959-                 sort_indices  =  np .argsort (compound_indices , kind = 'stable' )
960-             else :
961-                 # Quicksort 
962-                 sort_indices  =  np .argsort (compound_indices )
963-             # We must sort size_per_atom accordingly (Issue #3352). 
964-             size_per_atom  =  size_per_atom [sort_indices ]
965- 
966-         compound_masks  =  []
967-         atom_masks  =  []
968-         for  compound_size  in  unique_compound_sizes :
969-             compound_masks .append (compound_sizes  ==  compound_size )
944+ 
945+         # create new cache or invalidate cache when compound indices changed 
946+         if  (
947+             cache_key  not  in   self ._cache 
948+             or  np .all (self ._cache [cache_key ]["compound_indices" ]
949+                       !=  compound_indices )):
950+             compound_sizes  =  np .bincount (compound_indices )
951+             size_per_atom  =  compound_sizes [compound_indices ]
952+             compound_sizes  =  compound_sizes [compound_sizes  !=  0 ]
953+             unique_compound_sizes  =  unique_int_1d (compound_sizes )
954+ 
955+             # Are we already sorted? argsorting and fancy-indexing can be 
956+             # expensive so we do a quick pre-check. 
957+             needs_sorting  =  np .any (np .diff (compound_indices ) <  0 )
970958            if  needs_sorting :
971-                 atom_masks .append (sort_indices [size_per_atom  ==  compound_size ]
972-                                    .reshape (- 1 , compound_size ))
973-             else :
974-                 atom_masks .append (np .where (size_per_atom  ==  compound_size )[0 ]
975-                                    .reshape (- 1 , compound_size ))
959+                 # stable sort ensures reproducibility, especially concerning 
960+                 # who gets to be a compound's atom[0] and be a reference for 
961+                 # unwrap. 
962+                 if  stable_sort :
963+                     sort_indices  =  np .argsort (compound_indices , kind = 'stable' )
964+                 else :
965+                     # Quicksort 
966+                     sort_indices  =  np .argsort (compound_indices )
967+                 # We must sort size_per_atom accordingly (Issue #3352). 
968+                 size_per_atom  =  size_per_atom [sort_indices ]
969+ 
970+             compound_masks  =  []
971+             atom_masks  =  []
972+             for  compound_size  in  unique_compound_sizes :
973+                 compound_masks .append (compound_sizes  ==  compound_size )
974+                 if  needs_sorting :
975+                     atom_masks .append (sort_indices [size_per_atom 
976+                                                    ==  compound_size ]
977+                                       .reshape (- 1 , compound_size ))
978+                 else :
979+                     atom_masks .append (np .where (size_per_atom 
980+                                                ==  compound_size )[0 ]
981+                                       .reshape (- 1 , compound_size ))
982+ 
983+             self ._cache [cache_key ] =  {
984+                 "compound_indices" : compound_indices ,
985+                 "data" : (atom_masks , compound_masks , len (compound_sizes ))
986+             }
976987
977-         return  atom_masks ,  compound_masks ,  len ( compound_sizes ) 
988+         return  self . _cache [ cache_key ][ "data" ] 
978989
979990    @warn_if_not_unique  
980991    @_pbc_to_wrap  
@@ -3200,7 +3211,7 @@ def select_atoms(self, sel, *othersel, periodic=True, rtol=1e-05,
32003211                    universe = mda.Universe(PSF, DCD) 
32013212                    guessed_elements = guess_types(universe.atoms.names) 
32023213                    universe.add_TopologyAttr('elements', guessed_elements) 
3203-                      
3214+ 
32043215                .. doctest:: AtomGroup.select_atoms.smarts 
32053216
32063217                    >>> universe.select_atoms("smarts C", smarts_kwargs={"maxMatches": 100}) 
0 commit comments