Skip to content

Commit 9eb180b

Browse files
ilan-golddcherianpre-commit-ci[bot]
authored
(feat): Support for pandas ExtensionArray (#8723)
* (feat): first pass supporting extension arrays * (feat): categorical tests + functionality * (feat): use multiple dispatch for unimplemented ops * (feat): implement (not really) broadcasting * (chore): add more `groupby` tests * (fix): fix more groupby incompatibility * (bug): fix unused categories * (chore): refactor dispatched methods + tests * (fix): shared type should check for extension arrays first and then fall back to numpy * (refactor): tests moved * (chore): more higher level tests * (feat): to/from dataframe * (chore): check for plum import * (fix): `__setitem__`/`__getitem__` * (chore): disallow stacking * (fix): `pyproject.toml` * (fix): `as_shared_type` fix * (chore): add variable tests * (fix): dask + categoricals * (chore): notes/docs * (chore): remove old testing file * (chore): remove commented out code * (fix): import plum dispatch * (refactor): use `is_extension_array_dtype` as much as possible * (refactor): `extension_array`->`array` + move to `indexing` * (refactor): change order of classes * (chore): add small pyarrow test * (fix): fix some mypy issues * (fix): don't register unregisterable method * (fix): appease mypy * (fix): more sensible default implementations allow most use without `plum` * (fix): handling `pyarrow` tests * (fix): actually do import correctly * (fix): `reduce` condition * (fix): column ordering for dataframes * (refactor): remove encoding business * (refactor): raise error for dask + extension array * (fix): only wrap `ExtensionDuckArray` that has a `.array` which is a pandas extension array * (fix): use duck array equality method, not pandas * (refactor): bye plum! 
* (fix): `and` to `or` for casting to `ExtensionDuckArray` * (fix): check for class, not type * (fix): only support native endianness * (refactor): no need for superfluous checks in `_maybe_wrap_data` * (chore): clean up docs to no longer reference `plum` * (fix): no longer allow `ExtensionDuckArray` to wrap `ExtensionDuckArray` * (refactor): move `implements` logic to `indexing` * (refactor): `indexing.py` -> `extension_array.py` * (refactor): `ExtensionDuckArray` -> `PandasExtensionArray` * (fix): add writeable property * (fix): don't check writeable for `PandasExtensionArray` * (fix): move check earlier * (refactor): correct guard clause * (chore): remove unnecessary `AttributeError` * (feat): singleton wrapped as array * (feat): remove shared dtype casting * (feat): loop once over `dataframe.items` * (feat): add `__len__` attribute * (fix): ensure constructor receives `pd.Categorical` * Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian <[email protected]> * Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian <[email protected]> * (fix): drop condition for categorical corrected * Apply suggestions from code review * (chore): test `chunk` behavior * Update xarray/core/variable.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * (fix): bring back error * (chore): add test for dropping cat for mean * Update whats-new.rst * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 60f3e74 commit 9eb180b

16 files changed

+434
-43
lines changed

doc/whats-new.rst

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ New Features
2424
~~~~~~~~~~~~
2525
- New "random" method for converting to and from 360_day calendars (:pull:`8603`).
2626
By `Pascal Bourgault <https://github.com/aulemahal>`_.
27-
27+
- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array
28+
by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Categorical`,
29+
for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray`
30+
then, such as broadcasting.
31+
By `Ilan Gold <https://github.com/ilan-gold>`_.
2832

2933
Breaking changes
3034
~~~~~~~~~~~~~~~~

properties/test_pandas_roundtrip.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from hypothesis import given # isort:skip
1818

1919
numeric_dtypes = st.one_of(
20-
npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes()
20+
npst.unsigned_integer_dtypes(endianness="="),
21+
npst.integer_dtypes(endianness="="),
22+
npst.floating_dtypes(endianness="="),
2123
)
2224

2325
numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ module = [
130130
"opt_einsum.*",
131131
"pandas.*",
132132
"pooch.*",
133+
"pyarrow.*",
133134
"pydap.*",
134135
"pytest.*",
135136
"scipy.*",

xarray/core/dataset.py

+47-13
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload
2525

2626
import numpy as np
27+
from pandas.api.types import is_extension_array_dtype
2728

2829
# remove once numpy 2.0 is the oldest supported version
2930
try:
@@ -6852,10 +6853,13 @@ def reduce(
68526853
if (
68536854
# Some reduction functions (e.g. std, var) need to run on variables
68546855
# that don't have the reduce dims: PR5393
6855-
not reduce_dims
6856-
or not numeric_only
6857-
or np.issubdtype(var.dtype, np.number)
6858-
or (var.dtype == np.bool_)
6856+
not is_extension_array_dtype(var.dtype)
6857+
and (
6858+
not reduce_dims
6859+
or not numeric_only
6860+
or np.issubdtype(var.dtype, np.number)
6861+
or (var.dtype == np.bool_)
6862+
)
68596863
):
68606864
# prefer to aggregate over axis=None rather than
68616865
# axis=(0, 1) if they will be equivalent, because
@@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
71687172
)
71697173

71707174
def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
7171-
columns = [k for k in self.variables if k not in self.dims]
7175+
columns_in_order = [k for k in self.variables if k not in self.dims]
7176+
non_extension_array_columns = [
7177+
k
7178+
for k in columns_in_order
7179+
if not is_extension_array_dtype(self.variables[k].data)
7180+
]
7181+
extension_array_columns = [
7182+
k
7183+
for k in columns_in_order
7184+
if is_extension_array_dtype(self.variables[k].data)
7185+
]
71727186
data = [
71737187
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
7174-
for k in columns
7188+
for k in non_extension_array_columns
71757189
]
71767190
index = self.coords.to_index([*ordered_dims])
7177-
return pd.DataFrame(dict(zip(columns, data)), index=index)
7191+
broadcasted_df = pd.DataFrame(
7192+
dict(zip(non_extension_array_columns, data)), index=index
7193+
)
7194+
for extension_array_column in extension_array_columns:
7195+
extension_array = self.variables[extension_array_column].data.array
7196+
index = self[self.variables[extension_array_column].dims[0]].data
7197+
extension_array_df = pd.DataFrame(
7198+
{extension_array_column: extension_array},
7199+
index=self[self.variables[extension_array_column].dims[0]].data,
7200+
)
7201+
extension_array_df.index.name = self.variables[extension_array_column].dims[
7202+
0
7203+
]
7204+
broadcasted_df = broadcasted_df.join(extension_array_df)
7205+
return broadcasted_df[columns_in_order]
71787206

71797207
def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
71807208
"""Convert this dataset into a pandas.DataFrame.
@@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73217349
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
73227350
)
73237351

7324-
# Cast to a NumPy array first, in case the Series is a pandas Extension
7325-
# array (which doesn't have a valid NumPy dtype)
7326-
# TODO: allow users to control how this casting happens, e.g., by
7327-
# forwarding arguments to pandas.Series.to_numpy?
7328-
arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
7352+
arrays = []
7353+
extension_arrays = []
7354+
for k, v in dataframe.items():
7355+
if not is_extension_array_dtype(v):
7356+
arrays.append((k, np.asarray(v)))
7357+
else:
7358+
extension_arrays.append((k, v))
73297359

73307360
indexes: dict[Hashable, Index] = {}
73317361
index_vars: dict[Hashable, Variable] = {}
@@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73397369
xr_idx = PandasIndex(lev, dim)
73407370
indexes[dim] = xr_idx
73417371
index_vars.update(xr_idx.create_variables())
7372+
arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
7373+
extension_arrays = []
73427374
else:
73437375
index_name = idx.name if idx.name is not None else "index"
73447376
dims = (index_name,)
@@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73527384
obj._set_sparse_data_from_dataframe(idx, arrays, dims)
73537385
else:
73547386
obj._set_numpy_data_from_dataframe(idx, arrays, dims)
7355-
return obj
7387+
for name, extension_array in extension_arrays:
7388+
obj[name] = (dims, extension_array)
7389+
return obj[dataframe.columns] if len(dataframe.columns) else obj
73567390

73577391
def to_dask_dataframe(
73587392
self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False

xarray/core/duck_array_ops.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from numpy import concatenate as _concatenate
3333
from numpy.lib.stride_tricks import sliding_window_view # noqa
3434
from packaging.version import Version
35+
from pandas.api.types import is_extension_array_dtype
3536

3637
from xarray.core import dask_array_ops, dtypes, nputils
3738
from xarray.core.options import OPTIONS
@@ -156,7 +157,7 @@ def isnull(data):
156157
return full_like(data, dtype=bool, fill_value=False)
157158
else:
158159
# at this point, array should have dtype=object
159-
if isinstance(data, np.ndarray):
160+
if isinstance(data, np.ndarray) or is_extension_array_dtype(data):
160161
return pandas_isnull(data)
161162
else:
162163
# Not reachable yet, but intended for use with other duck array
@@ -221,9 +222,19 @@ def asarray(data, xp=np):
221222

222223
def as_shared_dtype(scalars_or_arrays, xp=np):
223224
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
224-
array_type_cupy = array_type("cupy")
225-
if array_type_cupy and any(
226-
isinstance(x, array_type_cupy) for x in scalars_or_arrays
225+
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
226+
extension_array_types = [
227+
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
228+
]
229+
if len(extension_array_types) == len(scalars_or_arrays) and all(
230+
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
231+
):
232+
return scalars_or_arrays
233+
raise ValueError(
234+
f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
235+
)
236+
elif array_type_cupy := array_type("cupy") and any( # noqa: F841
237+
isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821
227238
):
228239
import cupy as cp
229240

xarray/core/extension_array.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Sequence
4+
from typing import Callable, Generic
5+
6+
import numpy as np
7+
import pandas as pd
8+
from pandas.api.types import is_extension_array_dtype
9+
10+
from xarray.core.types import DTypeLikeSave, T_ExtensionArray
11+
12+
HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}
13+
14+
15+
def implements(numpy_function):
16+
"""Register an __array_function__ implementation for MyArray objects."""
17+
18+
def decorator(func):
19+
HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func
20+
return func
21+
22+
return decorator
23+
24+
25+
@implements(np.issubdtype)
26+
def __extension_duck_array__issubdtype(
27+
extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave
28+
) -> bool:
29+
return False # never want a function to think a pandas extension dtype is a subtype of numpy
30+
31+
32+
@implements(np.broadcast_to)
33+
def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
34+
if shape[0] == len(arr) and len(shape) == 1:
35+
return arr
36+
raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.")
37+
38+
39+
@implements(np.stack)
40+
def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
41+
raise NotImplementedError("Cannot stack 1d-only pandas categorical array.")
42+
43+
44+
@implements(np.concatenate)
45+
def __extension_duck_array__concatenate(
46+
arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
47+
) -> T_ExtensionArray:
48+
return type(arrays[0])._concat_same_type(arrays)
49+
50+
51+
@implements(np.where)
52+
def __extension_duck_array__where(
53+
condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
54+
) -> T_ExtensionArray:
55+
if (
56+
isinstance(x, pd.Categorical)
57+
and isinstance(y, pd.Categorical)
58+
and x.dtype != y.dtype
59+
):
60+
x = x.add_categories(set(y.categories).difference(set(x.categories)))
61+
y = y.add_categories(set(x.categories).difference(set(y.categories)))
62+
return pd.Series(x).where(condition, pd.Series(y)).array
63+
64+
65+
class PandasExtensionArray(Generic[T_ExtensionArray]):
66+
array: T_ExtensionArray
67+
68+
def __init__(self, array: T_ExtensionArray):
69+
"""NEP-18 compliant wrapper for pandas extension arrays.
70+
71+
Parameters
72+
----------
73+
array : T_ExtensionArray
74+
The array to be wrapped upon, e.g., :py:class:`xarray.Variable` creation.
75+
```
76+
"""
77+
if not isinstance(array, pd.api.extensions.ExtensionArray):
78+
raise TypeError(f"{array} is not an pandas ExtensionArray.")
79+
self.array = array
80+
81+
def __array_function__(self, func, types, args, kwargs):
82+
def replace_duck_with_extension_array(args) -> list:
83+
args_as_list = list(args)
84+
for index, value in enumerate(args_as_list):
85+
if isinstance(value, PandasExtensionArray):
86+
args_as_list[index] = value.array
87+
elif isinstance(
88+
value, tuple
89+
): # should handle more than just tuple? iterable?
90+
args_as_list[index] = tuple(
91+
replace_duck_with_extension_array(value)
92+
)
93+
elif isinstance(value, list):
94+
args_as_list[index] = replace_duck_with_extension_array(value)
95+
return args_as_list
96+
97+
args = tuple(replace_duck_with_extension_array(args))
98+
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
99+
return func(*args, **kwargs)
100+
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
101+
if is_extension_array_dtype(res):
102+
return type(self)[type(res)](res)
103+
return res
104+
105+
def __array_ufunc__(ufunc, method, *inputs, **kwargs):
106+
return ufunc(*inputs, **kwargs)
107+
108+
def __repr__(self):
109+
return f"{type(self)}(array={repr(self.array)})"
110+
111+
def __getattr__(self, attr: str) -> object:
112+
return getattr(self.array, attr)
113+
114+
def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
115+
item = self.array[key]
116+
if is_extension_array_dtype(item):
117+
return type(self)(item)
118+
if np.isscalar(item):
119+
return type(self)(type(self.array)([item]))
120+
return item
121+
122+
def __setitem__(self, key, val):
123+
self.array[key] = val
124+
125+
def __eq__(self, other):
126+
if np.isscalar(other):
127+
other = type(self)(type(self.array)([other]))
128+
if isinstance(other, PandasExtensionArray):
129+
return self.array == other.array
130+
return self.array == other
131+
132+
def __ne__(self, other):
133+
return ~(self == other)
134+
135+
def __len__(self):
136+
return len(self.array)

xarray/core/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ def copy(
167167
# hopefully in the future we can narrow this down more:
168168
T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True)
169169

170+
# For typing pandas extension arrays.
171+
T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray)
172+
170173

171174
ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
172175
VarCompatible = Union["Variable", "ScalarOrArray"]

xarray/core/variable.py

+10
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
import numpy as np
1414
import pandas as pd
1515
from numpy.typing import ArrayLike
16+
from pandas.api.types import is_extension_array_dtype
1617

1718
import xarray as xr # only for Dataset and DataArray
1819
from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils
1920
from xarray.core.arithmetic import VariableArithmetic
2021
from xarray.core.common import AbstractArray
22+
from xarray.core.extension_array import PandasExtensionArray
2123
from xarray.core.indexing import (
2224
BasicIndexer,
2325
OuterIndexer,
@@ -47,6 +49,7 @@
4749
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
4850
indexing.ExplicitlyIndexed,
4951
pd.Index,
52+
pd.api.extensions.ExtensionArray,
5053
)
5154
# https://github.com/python/mypy/issues/224
5255
BASIC_INDEXING_TYPES = integer_types + (slice,)
@@ -184,6 +187,8 @@ def _maybe_wrap_data(data):
184187
"""
185188
if isinstance(data, pd.Index):
186189
return PandasIndexingAdapter(data)
190+
if isinstance(data, pd.api.extensions.ExtensionArray):
191+
return PandasExtensionArray[type(data)](data)
187192
return data
188193

189194

@@ -2570,6 +2575,11 @@ def chunk( # type: ignore[override]
25702575
dask.array.from_array
25712576
"""
25722577

2578+
if is_extension_array_dtype(self):
2579+
raise ValueError(
2580+
f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first."
2581+
)
2582+
25732583
if from_array_kwargs is None:
25742584
from_array_kwargs = {}
25752585

xarray/testing/strategies.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
4545
Generates only those numpy dtypes which xarray can handle.
4646
4747
Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes.
48-
Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows.
48+
Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. Checks only native endianness.
4949
5050
Requires the hypothesis package to be installed.
5151
@@ -56,10 +56,10 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
5656
# TODO should this be exposed publicly?
5757
# We should at least decide what the set of numpy dtypes that xarray officially supports is.
5858
return (
59-
npst.integer_dtypes()
60-
| npst.unsigned_integer_dtypes()
61-
| npst.floating_dtypes()
62-
| npst.complex_number_dtypes()
59+
npst.integer_dtypes(endianness="=")
60+
| npst.unsigned_integer_dtypes(endianness="=")
61+
| npst.floating_dtypes(endianness="=")
62+
| npst.complex_number_dtypes(endianness="=")
6363
# | npst.datetime64_dtypes()
6464
# | npst.timedelta64_dtypes()
6565
# | npst.unicode_string_dtypes()

0 commit comments

Comments
 (0)