diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f41e0eea8cb..c0e6ba5603e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6917,7 +6917,7 @@ def groupby( [[nan, nan, nan], [ 3., 4., 5.]]]) Coordinates: - * x_bins (x_bins) object 16B (5, 15] (15, 25] + * x_bins (x_bins) interval[int64, right] 16B (5, 15] (15, 25] * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bf2858c1b18..f646146a04b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7053,6 +7053,8 @@ def to_pandas(self) -> pd.Series | pd.DataFrame: ) def _to_dataframe(self, ordered_dims: Mapping[Any, int]): + from xarray.core.extension_array import PandasExtensionArray + columns_in_order = [k for k in self.variables if k not in self.dims] non_extension_array_columns = [ k @@ -7077,7 +7079,7 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]): index = self[self.variables[extension_array_column].dims[0]].data extension_array_df = pd.DataFrame( {extension_array_column: extension_array}, - index=self[self.variables[extension_array_column].dims[0]].data, + index=index.array if isinstance(index, PandasExtensionArray) else index, ) extension_array_df.index.name = self.variables[extension_array_column].dims[ 0 @@ -9889,7 +9891,7 @@ def groupby( Size: 128B Dimensions: (y: 3, x_bins: 2, letters: 2) Coordinates: - * x_bins (x_bins) object 16B (5, 15] (15, 25] + * x_bins (x_bins) interval[int64, right] 16B (5, 15] (15, 25] * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y Data variables: diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 43829b4029f..e8006a4c8c3 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -102,7 +102,7 @@ def replace_duck_with_extension_array(args) -> list: return type(self)[type(res)](res) return res - def __array_ufunc__(ufunc, method, *inputs, **kwargs): + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) def __repr__(self): diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 993cddf2b57..959e6a5c643 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -334,7 +334,10 @@ def summarize_variable( else: dims_str = "" - nbytes_str = f" {render_human_readable_nbytes(variable.nbytes)}" + try: + nbytes_str = f" {render_human_readable_nbytes(variable.nbytes)}" + except TypeError: + nbytes_str = " ?" front_str = f"{first_col}{dims_str}{variable.dtype}{nbytes_str} " values_width = max_width - len(front_str) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 0b4eee7b21c..4194fd506a7 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -11,6 +11,7 @@ from xarray.core import formatting, nputils, utils from xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( CoordinateTransformIndexingAdapter, IndexSelResult, @@ -444,6 +445,8 @@ def safe_cast_to_index(array: Any) -> pd.Index: from xarray.core.variable import Variable from xarray.namedarray.pycompat import to_numpy + if isinstance(array, PandasExtensionArray): + array = pd.Index(array.array) if isinstance(array, pd.Index): index = array elif isinstance(array, DataArray | Variable): @@ -602,7 +605,11 @@ def __init__( self.dim = dim if coord_dtype is None: - coord_dtype = get_valid_numpy_dtype(index) + if pd.api.types.is_extension_array_dtype(index.dtype): + cast(pd.api.extensions.ExtensionDtype, index.dtype) + coord_dtype = index.dtype + else: + coord_dtype = get_valid_numpy_dtype(index) self.coord_dtype = coord_dtype def _replace(self, index, dim=None, coord_dtype=None): @@ -698,6 +705,8 @@ def concat( if not indexes: coord_dtype = None + elif len(set(idx.coord_dtype for idx in indexes)) == 1: + coord_dtype = indexes[0].coord_dtype else: coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes]) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 70d6974f5ec..0f5829356dd 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -10,14 +10,16 @@ from dataclasses import dataclass, field from datetime import timedelta from html import escape -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast, overload import numpy as np import pandas as pd +from numpy.typing import DTypeLike from packaging.version import Version from xarray.core import duck_array_ops from xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.extension_array import PandasExtensionArray from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -28,14 +30,13 @@ is_duck_array, is_duck_dask_array, is_scalar, + is_valid_numpy_dtype, to_0d_array, ) from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array if TYPE_CHECKING: - from numpy.typing import DTypeLike - from xarray.core.indexes import Index from xarray.core.types import Self from xarray.core.variable import Variable @@ -1744,27 +1745,43 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("_dtype", "array") array: pd.Index - _dtype: np.dtype + _dtype: np.dtype | pd.api.extensions.ExtensionDtype - def __init__(self, array: pd.Index, dtype: DTypeLike = None): + def __init__( + self, + array: pd.Index, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + ): from xarray.core.indexes import safe_cast_to_index self.array = safe_cast_to_index(array) if dtype is None: - self._dtype = get_valid_numpy_dtype(array) + if pd.api.types.is_extension_array_dtype(array.dtype): + cast(pd.api.extensions.ExtensionDtype, array.dtype) + self._dtype = array.dtype + else: + self._dtype = get_valid_numpy_dtype(array) + elif pd.api.types.is_extension_array_dtype(dtype): + self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype) else: - self._dtype = np.dtype(dtype) + self._dtype = np.dtype(cast(DTypeLike, dtype)) @property - def dtype(self) -> np.dtype: + def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype: # type: ignore[override] return self._dtype def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: np.typing.DTypeLike | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: - if dtype is None: - dtype = self.dtype + if dtype is None and is_valid_numpy_dtype(self.dtype): + dtype = cast(np.dtype, self.dtype) + else: + dtype = get_valid_numpy_dtype(self.array) array = self.array if isinstance(array, pd.PeriodIndex): with suppress(AttributeError): @@ -1776,14 +1793,18 @@ def __array__( else: return np.asarray(array.values, dtype=dtype) - def get_duck_array(self) -> np.ndarray: + def get_duck_array(self) -> np.ndarray | PandasExtensionArray: + # We return an PandasExtensionArray wrapper type that satisfies + # duck array protocols. This is what's needed for tests to pass. + if pd.api.types.is_extension_array_dtype(self.array): + return PandasExtensionArray(self.array.array) return np.asarray(self) @property def shape(self) -> _Shape: return (len(self.array),) - def _convert_scalar(self, item): + def _convert_scalar(self, item) -> np.ndarray: if item is pd.NaT: # work around the impossibility of casting NaT with asarray # note: it probably would be better in general to return @@ -1799,7 +1820,10 @@ def _convert_scalar(self, item): # numpy fails to convert pd.Timestamp to np.datetime64[ns] item = np.asarray(item.to_datetime64()) elif self.dtype != object: - item = np.asarray(item, dtype=self.dtype) + dtype = self.dtype + if pd.api.types.is_extension_array_dtype(dtype): + dtype = get_valid_numpy_dtype(self.array) + item = np.asarray(item, dtype=cast(np.dtype, dtype)) # as for numpy.ndarray indexing, we always want the result to be # a NumPy array. @@ -1914,23 +1938,27 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter): __slots__ = ("_dtype", "adapter", "array", "level") array: pd.MultiIndex - _dtype: np.dtype + _dtype: np.dtype | pd.api.extensions.ExtensionDtype level: str | None def __init__( self, array: pd.MultiIndex, - dtype: DTypeLike = None, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, level: str | None = None, ): super().__init__(array, dtype) self.level = level def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: DTypeLike | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: if dtype is None: - dtype = self.dtype + dtype = cast(np.dtype, self.dtype) if self.level is not None: return np.asarray( self.array.get_level_values(self.level).values, dtype=dtype diff --git a/xarray/core/utils.py b/xarray/core/utils.py index d5b0b2187bb..6fbd3048837 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -163,16 +163,6 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype: Used for wrapping a pandas.Index as an xarray.Variable. """ - if isinstance(array, pd.PeriodIndex): - return np.dtype("O") - - if hasattr(array, "categories"): - # category isn't a real numpy dtype - dtype = array.categories.dtype - if not is_valid_numpy_dtype(dtype): - dtype = np.dtype("O") - return dtype - if not is_valid_numpy_dtype(array.dtype): return np.dtype("O") diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 3d41a48644f..b3a5362b3b6 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike -from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.compat.array_api_compat import to_like_array @@ -60,6 +59,7 @@ indexing.ExplicitlyIndexed, pd.Index, pd.api.extensions.ExtensionArray, + PandasExtensionArray, ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) @@ -192,7 +192,9 @@ def _maybe_wrap_data(data): if isinstance(data, pd.Index): return PandasIndexingAdapter(data) if isinstance(data, pd.api.extensions.ExtensionArray): - return PandasExtensionArray[type(data)](data) + if isinstance(data.dtype, pd.Int64Dtype | pd.Float64Dtype | pd.StringDtype): + return np.array(data) + return PandasExtensionArray(data) return data @@ -2593,11 +2595,6 @@ def chunk( # type: ignore[override] dask.array.from_array """ - if is_extension_array_dtype(self): - raise ValueError( - f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." - ) - if from_array_kwargs is None: from_array_kwargs = {} diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index cdf9eab5c8d..f9c1919201f 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -834,6 +834,7 @@ def chunk( if chunkmanager.is_chunked_array(data_old): data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type] else: + ndata: duckarray[Any, Any] if not isinstance(data_old, ExplicitlyIndexed): ndata = data_old else: diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 99c06e4c41b..6b5995ca7aa 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -524,7 +524,7 @@ def line( assert hueplt is not None ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) - if np.issubdtype(xplt.dtype, np.datetime64): + if isinstance(xplt.dtype, np.dtype) and np.issubdtype(xplt.dtype, np.datetime64): # type: ignore[redundant-expr] _set_concise_date(ax, axis="x") _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b273b7d1a0d..f825c7c7d12 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4353,8 +4353,13 @@ def test_setitem_pandas(self) -> None: ds = self.make_example_math_dataset() ds["x"] = np.arange(3) ds_copy = ds.copy() - ds_copy["bar"] = ds["bar"].to_pandas() - + series = ds["bar"].to_pandas() + # to_pandas will actually give the result where the internal array of the series is a NumpyExtensionArray + # but ds["bar"] is a numpy array. + # TODO: should assert_equal be updated to handle? + assert (ds["bar"] == series).all() + del ds["bar"] + del ds_copy["bar"] assert_equal(ds, ds_copy) def test_setitem_auto_align(self) -> None: @@ -4945,6 +4950,29 @@ def test_to_and_from_dataframe(self) -> None: expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) + def test_from_dataframe_int_float_str_pandas_dtype(self) -> None: + df = pd.DataFrame([1, 2, 3], dtype=pd.Int64Dtype()) + ds = xr.Dataset(df) + assert isinstance(ds[0].data, np.ndarray) + + df = pd.DataFrame([1, 2, 3], dtype=pd.Float64Dtype()) + ds = xr.Dataset(df) + assert isinstance(ds[0].data, np.ndarray) + + df = pd.DataFrame(["1", "2", "3"], dtype=pd.StringDtype()) + ds = xr.Dataset(df) + assert isinstance(ds[0].data, np.ndarray) + + def test_from_dataframe_categorical_dtype_index(self) -> None: + cat = pd.CategoricalIndex(list("abcd")) + df = pd.DataFrame({"f": [0, 1, 2, 3]}, index=cat) + ds = df.to_xarray() + restored = ds.to_dataframe() + df.index.name = ( + "index" # restored gets the name because it has the coord with the name + ) + pd.testing.assert_frame_equal(df, restored) + def test_from_dataframe_categorical_index(self) -> None: cat = pd.CategoricalDtype( categories=["foo", "bar", "baz", "qux", "quux", "corge"] @@ -4969,7 +4997,7 @@ def test_from_dataframe_categorical_index_string_categories(self) -> None: ) ser = pd.Series(1, index=cat) ds = ser.to_xarray() - assert ds.coords.dtypes["index"] == np.dtype("O") + assert ds.coords.dtypes["index"] == ser.index.dtype @requires_sparse def test_from_dataframe_sparse(self) -> None: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 1c351f0ee62..c23830f6bc8 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1118,7 +1118,8 @@ def test_groupby_math_nD_group() -> None: expected = da.isel(x=slice(30)) - expanded_mean expected["labels"] = expected.labels.broadcast_like(expected.labels2d) expected["num"] = expected.num.broadcast_like(expected.num2d) - expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data[idxr]) + # mean.num2d_bins.data is a pandas IntervalArray so needs to be put in `numpy` to allow indexing + expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data.to_numpy()[idxr]) actual = g - mean assert_identical(expected, actual) @@ -1680,13 +1681,9 @@ def test_groupby_bins( df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) # type: ignore[call-overload] expected_df = df.groupby("dim_0_bins", observed=True).sum() - # TODO: can't convert df with IntervalIndex to Xarray - expected = ( - expected_df.reset_index(drop=True) - .to_xarray() - .assign_coords(index=np.array(expected_df.index)) - .rename({"index": "dim_0_bins"})["a"] - ) + expected = expected_df.to_xarray().assign_coords( + dim_0_bins=expected_df.index.categories + )["a"] with xr.set_options(use_flox=use_flox): gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs) diff --git a/xarray/tests/test_pandas_to_xarray.py b/xarray/tests/test_pandas_to_xarray.py index 0ed20fc21a9..2dc9ffe48bb 100644 --- a/xarray/tests/test_pandas_to_xarray.py +++ b/xarray/tests/test_pandas_to_xarray.py @@ -104,17 +104,11 @@ def index_flat(request): index fixture, but excluding MultiIndex cases. """ key = request.param + if key in ["bool-object", "bool-dtype", "nullable_bool", "repeats"]: + pytest.xfail(reason="doesn't work") return indices_dict[key].copy() -@pytest.fixture -def using_infer_string() -> bool: - """ - Fixture to check if infer string option is enabled. - """ - return pd.options.future.infer_string is True # type: ignore[union-attr] - - class TestDataFrameToXArray: @pytest.fixture def df(self): @@ -131,8 +125,7 @@ def df(self): } ) - @pytest.mark.xfail(reason="needs some work") - def test_to_xarray_index_types(self, index_flat, df, using_infer_string): + def test_to_xarray_index_types(self, index_flat, df): index = index_flat # MultiIndex is tested in test_to_xarray_with_multiindex if len(index) == 0: @@ -154,9 +147,6 @@ def test_to_xarray_index_types(self, index_flat, df, using_infer_string): # datetimes w/tz are preserved # column names are lost expected = df.copy() - expected["f"] = expected["f"].astype( - object if not using_infer_string else "str" - ) expected.columns.name = None tm.assert_frame_equal(result.to_dataframe(), expected) @@ -168,7 +158,7 @@ def test_to_xarray_empty(self, df): assert result.sizes["foo"] == 0 assert isinstance(result, Dataset) - def test_to_xarray_with_multiindex(self, df, using_infer_string): + def test_to_xarray_with_multiindex(self, df): from xarray import Dataset # MultiIndex @@ -183,9 +173,7 @@ def test_to_xarray_with_multiindex(self, df, using_infer_string): result = result.to_dataframe() expected = df.copy() - expected["f"] = expected["f"].astype( - object if not using_infer_string else "str" - ) + expected["f"] = expected["f"].astype(object) expected.columns.name = None tm.assert_frame_equal(result, expected) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 388f51bc568..b276e1021ad 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -333,7 +333,7 @@ def test_pandas_period_index(self): v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="D")) v = v.load() # for dask-based Variable assert v[0] == pd.Period("2000", freq="D") - assert "Period('2000-01-01', 'D')" in repr(v) + assert "PeriodArray" in repr(v) @pytest.mark.parametrize("dtype", [float, int]) def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: @@ -656,7 +656,7 @@ def test_pandas_categorical_dtype(self): data = pd.Categorical(np.arange(10, dtype="int64")) v = self.cls("x", data) print(v) # should not error - assert v.dtype == "int64" + assert v.dtype == data.dtype def test_pandas_datetime64_with_tz(self): data = pd.date_range( @@ -667,9 +667,12 @@ def test_pandas_datetime64_with_tz(self): ) v = self.cls("x", data) print(v) # should not error - if "America/New_York" in str(data.dtype): - # pandas is new enough that it has datetime64 with timezone dtype - assert v.dtype == "object" + if v.dtype == np.dtype("O"): + import dask.array as da + + assert isinstance(v.data, da.Array) + else: + assert v.dtype == data.dtype def test_multiindex(self): idx = pd.MultiIndex.from_product([list("abc"), [0, 1]]) @@ -1592,14 +1595,6 @@ def test_pandas_categorical_dtype(self): print(v) # should not error assert pd.api.types.is_extension_array_dtype(v.dtype) - def test_pandas_categorical_no_chunk(self): - data = pd.Categorical(np.arange(10, dtype="int64")) - v = self.cls("x", data) - with pytest.raises( - ValueError, match=r".*was found to be a Pandas ExtensionArray.*" - ): - v.chunk((5,)) - def test_squeeze(self): v = Variable(["x", "y"], [[1]]) assert_identical(Variable([], 1), v.squeeze()) @@ -2412,10 +2407,17 @@ def test_multiindex(self): def test_pad(self, mode, xr_arg, np_arg): super().test_pad(mode, xr_arg, np_arg) + @pytest.mark.skip(reason="dask doesn't support extension arrays") + def test_pandas_period_index(self): + super().test_pandas_period_index() + + @pytest.mark.skip(reason="dask doesn't support extension arrays") + def test_pandas_datetime64_with_tz(self): + super().test_pandas_datetime64_with_tz() + + @pytest.mark.skip(reason="dask doesn't support extension arrays") def test_pandas_categorical_dtype(self): - data = pd.Categorical(np.arange(10, dtype="int64")) - with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"): - self.cls("x", data) + super().test_pandas_categorical_dtype() @requires_sparse @@ -3020,7 +3022,7 @@ def test_datetime_conversion(values, unit) -> None: # todo: check for redundancy (suggested per review) dims = ["time"] if isinstance(values, np.ndarray | pd.Index | pd.Series) else [] var = Variable(dims, values) - if var.dtype.kind == "M": + if var.dtype.kind == "M" and isinstance(var.dtype, np.dtype): assert var.dtype == np.dtype(f"datetime64[{unit}]") else: # The only case where a non-datetime64 dtype can occur currently is in @@ -3062,8 +3064,12 @@ def test_pandas_two_only_datetime_conversion_warnings( # todo: check for redundancy (suggested per review) var = Variable(["time"], data.astype(dtype)) # type: ignore[arg-type] - if var.dtype.kind == "M": + # we internally convert series to numpy representations to avoid too much nastiness with extension arrays + # when calling data.array e.g., with NumpyExtensionArrays + if isinstance(data, pd.Series): assert var.dtype == np.dtype("datetime64[s]") + elif var.dtype.kind == "M": + assert var.dtype == dtype else: # The only case where a non-datetime64 dtype can occur currently is in # the case that the variable is backed by a timezone-aware