From 6fb083477bc0b1f3eeccc62e10e4b477ae532346 Mon Sep 17 00:00:00 2001 From: Philipp A Date: Tue, 21 May 2019 10:41:25 +0200 Subject: [PATCH] Use unpack_index from our code (#151) --- anndata/base.py | 11 +++++------ anndata/h5py/h5sparse.py | 8 ++++---- anndata/utils.py | 23 +++++++++++++++++++++-- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/anndata/base.py b/anndata/base.py index 44fe92035..3ebea1c1c 100644 --- a/anndata/base.py +++ b/anndata/base.py @@ -18,7 +18,6 @@ from pandas.api.types import is_string_dtype, is_categorical from scipy import sparse from scipy.sparse import issparse -from scipy.sparse.sputils import IndexMixin from natsort import natsorted # try importing zarr @@ -43,7 +42,7 @@ def __rep__(): from .layers import AnnDataLayers from . import utils -from .utils import Index, get_n_items_idx +from .utils import Index, get_n_items_idx, unpack_index from .logging import anndata_logger as logger from .compat import PathLike @@ -444,7 +443,7 @@ class DataFrameView(_ViewMixin, pd.DataFrame): _metadata = ['_view_args'] -class Raw(IndexMixin): +class Raw: def __init__( self, adata: Optional['AnnData'] = None, @@ -531,7 +530,7 @@ def _normalize_indices(self, packed_index): packed_index = packed_index[0], packed_index[1].values if isinstance(packed_index[0], pd.Series): packed_index = packed_index[0].values, packed_index[1] - obs, var = super()._unpack_index(packed_index) + obs, var = unpack_index(packed_index) obs = _normalize_index(obs, self._adata.obs_names) var = _normalize_index(var, self.var_names) return obs, var @@ -551,7 +550,7 @@ def __init__(self, n_dims): super().__init__(msg) -class AnnData(IndexMixin, metaclass=utils.DeprecationMixinMeta): +class AnnData(metaclass=utils.DeprecationMixinMeta): """An annotated data matrix. :class:`~anndata.AnnData` stores a data matrix :attr:`X` together with annotations @@ -1298,7 +1297,7 @@ def _normalize_indices(self, index: Optional[Index]): # Needs to be refactored once we support a tuple of two arbitrary index types if any(isinstance(i, np.ndarray) and i.dtype == bool for i in index): return index - obs, var = super()._unpack_index(index) + obs, var = unpack_index(index) obs = _normalize_index(obs, self.obs_names) var = _normalize_index(var, self.var_names) return obs, var diff --git a/anndata/h5py/h5sparse.py b/anndata/h5py/h5sparse.py index 5cc940f0e..73d2922b7 100644 --- a/anndata/h5py/h5sparse.py +++ b/anndata/h5py/h5sparse.py @@ -7,8 +7,8 @@ import h5py import numpy as np import scipy.sparse as ss -from scipy.sparse.sputils import IndexMixin +from ..utils import unpack_index from ..compat import PathLike from .utils import _chunked_rows @@ -236,7 +236,7 @@ def _zero_many(self, i, j): _cs_matrix._zero_many = _zero_many -class SparseDataset(IndexMixin): +class SparseDataset: """Analogous to :class:`h5py.Dataset `, but for sparse matrices. """ @@ -255,7 +255,7 @@ def format_str(self): def __getitem__(self, index): if index == (): index = slice(None) - row, col = self._unpack_index(index) + row, col = unpack_index(index) format_class = get_format_class(self.format_str) mock_matrix = format_class(self.shape, dtype=self.dtype) mock_matrix.data = self.h5py_group['data'] @@ -265,7 +265,7 @@ def __getitem__(self, index): def __setitem__(self, index, value): if index == (): index = slice(None) - row, col = self._unpack_index(index) + row, col = unpack_index(index) format_class = get_format_class(self.format_str) mock_matrix = format_class(self.shape, dtype=self.dtype) mock_matrix.data = self.h5py_group['data'] diff --git a/anndata/utils.py b/anndata/utils.py index 9cdb21aa2..01ce7e1b8 100644 --- a/anndata/utils.py +++ b/anndata/utils.py @@ -1,9 +1,10 @@ import warnings from functools import wraps -from typing import Mapping, Any, Sequence, Union, Sized, Optional +from typing import Mapping, Any, Sequence, Union, Tuple import pandas as pd import numpy as np +from scipy.sparse import spmatrix from .logging import get_logger if False: @@ -132,7 +133,7 @@ def is_deprecated(attr): ] -Index = Union[slice, int, np.int64, np.ndarray, Sized] +Index = Union[slice, int, np.int64, np.ndarray, spmatrix] def get_n_items_idx(idx: Index, l: int): @@ -147,3 +148,21 @@ def get_n_items_idx(idx: Index, l: int): return 1 else: return len(idx) + + +def unpack_index(index: Union[Index, Tuple[Index, Index]]) -> Tuple[Index, Index]: + # handle indexing with boolean matrices + if ( + isinstance(index, (spmatrix, np.ndarray)) + and index.ndim == 2 + and index.dtype.kind == 'b' + ): return index.nonzero() + + if not isinstance(index, tuple): + return index, slice(None) + elif len(index) == 2: + return index + elif len(index) == 1: + return index[0], slice(None) + else: + raise IndexError('invalid number of indices')