diff --git a/anndata/base.py b/anndata/base.py index db72cae99..7d2206117 100644 --- a/anndata/base.py +++ b/anndata/base.py @@ -215,6 +215,7 @@ def _fix_shapes(X): def _normalize_index(index, names): + assert names.dtype != float and names.dtype != int, 'Don’t call _normalize_index with non-categorical/string names' # the following is insanely slow for sequences, we replaced it using pandas below def name_idx(i): if isinstance(i, str): @@ -632,6 +633,12 @@ def __init__( filename=filename, filemode=filemode) def _init_as_view(self, adata_ref, oidx, vidx): + def get_n_items_idx(idx): + if isinstance(idx, np.ndarray) and idx.dtype == bool: + return idx.sum() + else: + return len(idx) + self._isview = True self._adata_ref = adata_ref self._oidx = oidx @@ -653,19 +660,19 @@ def _init_as_view(self, adata_ref, oidx, vidx): self._slice_uns_sparse_matrices_inplace(uns_new, self._oidx) # fix _n_obs, _n_vars if isinstance(oidx, slice): - self._n_obs = len(obs_sub.index) + self._n_obs = get_n_items_idx(obs_sub.index) elif isinstance(oidx, (int, np.int64)): self._n_obs = 1 elif isinstance(oidx, Sized): - self._n_obs = len(oidx) + self._n_obs = get_n_items_idx(oidx) else: raise KeyError('Unknown Index type') if isinstance(vidx, slice): - self._n_vars = len(var_sub.index) + self._n_vars = get_n_items_idx(var_sub.index) elif isinstance(vidx, (int, np.int64)): self._n_vars = 1 elif isinstance(vidx, Sized): - self._n_vars = len(vidx) + self._n_vars = get_n_items_idx(vidx) else: raise KeyError('Unknown Index type') # fix categories @@ -1154,14 +1161,17 @@ def _normalize_indices(self, index: Union[Tuple, pd.Series, np.ndarray, slice, i index = index[0], index[1].values if isinstance(index[0], pd.Series): index = index[0].values, index[1] - # one of the two has to be a slice - if not (isinstance(index[0], slice) or isinstance(index[1], slice)): - if isinstance(index[0], (int, str, None)) and isinstance(index[1], (int, str, None)): - pass # two scalars are fine - else: - raise NotImplementedError( - 'Slicing with two indices at the same time is not yet implemented. ' - 'As a workaround, do row and column slicing succesively.') + + no_slice = not any(isinstance(i, slice) for i in index) + both_scalars = all(isinstance(i, (int, str, type(None))) for i in index) + if no_slice and not both_scalars: + raise NotImplementedError( + 'Slicing with two indices at the same time is not yet implemented. ' + 'As a workaround, do row and column slicing succesively.') + # Speed up and error prevention for boolean indices (Don’t convert to integer indices) + # Needs to be refactored once we support a tuple of two arbitrary index types + if any(isinstance(i, np.ndarray) and i.dtype == bool for i in index): + return index obs, var = super(AnnData, self)._unpack_index(index) obs = _normalize_index(obs, self.obs_names) var = _normalize_index(var, self.var_names) diff --git a/anndata/tests/base.py b/anndata/tests/base.py index e5eb29bb8..cd08dac10 100644 --- a/anndata/tests/base.py +++ b/anndata/tests/base.py @@ -135,6 +135,14 @@ def test_slicing_remove_unused_categories(): assert adata[3:5].obs['k'].cat.categories.tolist() == ['b'] +def test_slicing_integer_index(): + adata = AnnData( + np.array([[0, 1, 2], [3, 4, 5]]), + var=dict(var_names=[10, 11, 12])) + sliced = adata[:, adata.X.sum(0) > 3] # This used to fail + assert sliced.shape == (2, 2) + + def test_get_subset_annotation(): adata = AnnData(np.array([[1, 2, 3], [4, 5, 6]]), dict(S=['A', 'B']),