24
24
from typing import IO , TYPE_CHECKING , Any , Callable , Generic , Literal , cast , overload
25
25
26
26
import numpy as np
27
+ from pandas .api .types import is_extension_array_dtype
27
28
28
29
# remove once numpy 2.0 is the oldest supported version
29
30
try :
@@ -6852,10 +6853,13 @@ def reduce(
6852
6853
if (
6853
6854
# Some reduction functions (e.g. std, var) need to run on variables
6854
6855
# that don't have the reduce dims: PR5393
6855
- not reduce_dims
6856
- or not numeric_only
6857
- or np .issubdtype (var .dtype , np .number )
6858
- or (var .dtype == np .bool_ )
6856
+ not is_extension_array_dtype (var .dtype )
6857
+ and (
6858
+ not reduce_dims
6859
+ or not numeric_only
6860
+ or np .issubdtype (var .dtype , np .number )
6861
+ or (var .dtype == np .bool_ )
6862
+ )
6859
6863
):
6860
6864
# prefer to aggregate over axis=None rather than
6861
6865
# axis=(0, 1) if they will be equivalent, because
@@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
7168
7172
)
7169
7173
7170
7174
def _to_dataframe (self , ordered_dims : Mapping [Any , int ]):
7171
- columns = [k for k in self .variables if k not in self .dims ]
7175
+ columns_in_order = [k for k in self .variables if k not in self .dims ]
7176
+ non_extension_array_columns = [
7177
+ k
7178
+ for k in columns_in_order
7179
+ if not is_extension_array_dtype (self .variables [k ].data )
7180
+ ]
7181
+ extension_array_columns = [
7182
+ k
7183
+ for k in columns_in_order
7184
+ if is_extension_array_dtype (self .variables [k ].data )
7185
+ ]
7172
7186
data = [
7173
7187
self ._variables [k ].set_dims (ordered_dims ).values .reshape (- 1 )
7174
- for k in columns
7188
+ for k in non_extension_array_columns
7175
7189
]
7176
7190
index = self .coords .to_index ([* ordered_dims ])
7177
- return pd .DataFrame (dict (zip (columns , data )), index = index )
7191
+ broadcasted_df = pd .DataFrame (
7192
+ dict (zip (non_extension_array_columns , data )), index = index
7193
+ )
7194
+ for extension_array_column in extension_array_columns :
7195
+ extension_array = self .variables [extension_array_column ].data .array
7196
+ index = self [self .variables [extension_array_column ].dims [0 ]].data
7197
+ extension_array_df = pd .DataFrame (
7198
+ {extension_array_column : extension_array },
7199
+ index = self [self .variables [extension_array_column ].dims [0 ]].data ,
7200
+ )
7201
+ extension_array_df .index .name = self .variables [extension_array_column ].dims [
7202
+ 0
7203
+ ]
7204
+ broadcasted_df = broadcasted_df .join (extension_array_df )
7205
+ return broadcasted_df [columns_in_order ]
7178
7206
7179
7207
def to_dataframe (self , dim_order : Sequence [Hashable ] | None = None ) -> pd .DataFrame :
7180
7208
"""Convert this dataset into a pandas.DataFrame.
@@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
7321
7349
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
7322
7350
)
7323
7351
7324
- # Cast to a NumPy array first, in case the Series is a pandas Extension
7325
- # array (which doesn't have a valid NumPy dtype)
7326
- # TODO: allow users to control how this casting happens, e.g., by
7327
- # forwarding arguments to pandas.Series.to_numpy?
7328
- arrays = [(k , np .asarray (v )) for k , v in dataframe .items ()]
7352
+ arrays = []
7353
+ extension_arrays = []
7354
+ for k , v in dataframe .items ():
7355
+ if not is_extension_array_dtype (v ):
7356
+ arrays .append ((k , np .asarray (v )))
7357
+ else :
7358
+ extension_arrays .append ((k , v ))
7329
7359
7330
7360
indexes : dict [Hashable , Index ] = {}
7331
7361
index_vars : dict [Hashable , Variable ] = {}
@@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
7339
7369
xr_idx = PandasIndex (lev , dim )
7340
7370
indexes [dim ] = xr_idx
7341
7371
index_vars .update (xr_idx .create_variables ())
7372
+ arrays += [(k , np .asarray (v )) for k , v in extension_arrays ]
7373
+ extension_arrays = []
7342
7374
else :
7343
7375
index_name = idx .name if idx .name is not None else "index"
7344
7376
dims = (index_name ,)
@@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
7352
7384
obj ._set_sparse_data_from_dataframe (idx , arrays , dims )
7353
7385
else :
7354
7386
obj ._set_numpy_data_from_dataframe (idx , arrays , dims )
7355
- return obj
7387
+ for name , extension_array in extension_arrays :
7388
+ obj [name ] = (dims , extension_array )
7389
+ return obj [dataframe .columns ] if len (dataframe .columns ) else obj
7356
7390
7357
7391
def to_dask_dataframe (
7358
7392
self , dim_order : Sequence [Hashable ] | None = None , set_index : bool = False
0 commit comments