2424from typing import IO , TYPE_CHECKING , Any , Callable , Generic , Literal , cast , overload
2525
2626import numpy as np
27+ from pandas .api .types import is_extension_array_dtype
2728
2829# remove once numpy 2.0 is the oldest supported version
2930try :
@@ -6852,10 +6853,13 @@ def reduce(
68526853 if (
68536854 # Some reduction functions (e.g. std, var) need to run on variables
68546855 # that don't have the reduce dims: PR5393
6855- not reduce_dims
6856- or not numeric_only
6857- or np .issubdtype (var .dtype , np .number )
6858- or (var .dtype == np .bool_ )
6856+ not is_extension_array_dtype (var .dtype )
6857+ and (
6858+ not reduce_dims
6859+ or not numeric_only
6860+ or np .issubdtype (var .dtype , np .number )
6861+ or (var .dtype == np .bool_ )
6862+ )
68596863 ):
68606864 # prefer to aggregate over axis=None rather than
68616865 # axis=(0, 1) if they will be equivalent, because
@@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
71687172 )
71697173
def _to_dataframe(self, ordered_dims: Mapping[Any, int]) -> pd.DataFrame:
    """Assemble a ``pandas.DataFrame`` from the non-dimension variables.

    Variables whose data is *not* a pandas extension array are expanded to
    ``ordered_dims`` with ``set_dims``, flattened to 1-D and used as the
    initial columns, indexed by ``self.coords.to_index([*ordered_dims])``.
    Variables whose data *is* a pandas extension array are not converted
    through ``.values``; each one is wrapped in its own single-column frame
    and joined onto the first frame.

    Parameters
    ----------
    ordered_dims : Mapping
        Dimension names mapped to sizes, in the order used to build the
        index of the resulting frame.

    Returns
    -------
    pandas.DataFrame
        Columns appear in the same order as the corresponding variables in
        ``self.variables``.
    """
    columns_in_order = [k for k in self.variables if k not in self.dims]

    # Classify every column in a single pass so the dtype check runs once
    # per variable instead of once per output list.
    non_extension_array_columns = []
    extension_array_columns = []
    for k in columns_in_order:
        if is_extension_array_dtype(self.variables[k].data):
            extension_array_columns.append(k)
        else:
            non_extension_array_columns.append(k)

    data = [
        self._variables[k].set_dims(ordered_dims).values.reshape(-1)
        for k in non_extension_array_columns
    ]
    index = self.coords.to_index([*ordered_dims])
    broadcasted_df = pd.DataFrame(
        dict(zip(non_extension_array_columns, data)), index=index
    )

    for extension_array_column in extension_array_columns:
        variable = self.variables[extension_array_column]
        # Only the first dimension is consulted here.
        # NOTE(review): this presumably relies on extension-array-backed
        # variables being one-dimensional -- confirm against callers.
        dim = variable.dims[0]
        extension_array_df = pd.DataFrame(
            {extension_array_column: variable.data.array},
            index=self[dim].data,
        )
        # Name the index after the dimension before joining onto the
        # frame built above.
        extension_array_df.index.name = dim
        broadcasted_df = broadcasted_df.join(extension_array_df)

    # The join appends extension columns at the end; restore variable order.
    return broadcasted_df[columns_in_order]
71787206
71797207 def to_dataframe (self , dim_order : Sequence [Hashable ] | None = None ) -> pd .DataFrame :
71807208 """Convert this dataset into a pandas.DataFrame.
@@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73217349 "cannot convert a DataFrame with a non-unique MultiIndex into xarray"
73227350 )
73237351
7324- # Cast to a NumPy array first, in case the Series is a pandas Extension
7325- # array (which doesn't have a valid NumPy dtype)
7326- # TODO: allow users to control how this casting happens, e.g., by
7327- # forwarding arguments to pandas.Series.to_numpy?
7328- arrays = [(k , np .asarray (v )) for k , v in dataframe .items ()]
7352+ arrays = []
7353+ extension_arrays = []
7354+ for k , v in dataframe .items ():
7355+ if not is_extension_array_dtype (v ):
7356+ arrays .append ((k , np .asarray (v )))
7357+ else :
7358+ extension_arrays .append ((k , v ))
73297359
73307360 indexes : dict [Hashable , Index ] = {}
73317361 index_vars : dict [Hashable , Variable ] = {}
@@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73397369 xr_idx = PandasIndex (lev , dim )
73407370 indexes [dim ] = xr_idx
73417371 index_vars .update (xr_idx .create_variables ())
7372+ arrays += [(k , np .asarray (v )) for k , v in extension_arrays ]
7373+ extension_arrays = []
73427374 else :
73437375 index_name = idx .name if idx .name is not None else "index"
73447376 dims = (index_name ,)
@@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73527384 obj ._set_sparse_data_from_dataframe (idx , arrays , dims )
73537385 else :
73547386 obj ._set_numpy_data_from_dataframe (idx , arrays , dims )
7355- return obj
7387+ for name , extension_array in extension_arrays :
7388+ obj [name ] = (dims , extension_array )
7389+ return obj [dataframe .columns ] if len (dataframe .columns ) else obj
73567390
73577391 def to_dask_dataframe (
73587392 self , dim_order : Sequence [Hashable ] | None = None , set_index : bool = False
0 commit comments