Skip to content

Commit 9303302

Browse files
authored
BUG: nullable groupby result dtypes (#46197)
1 parent aec0f57 commit 9303302

File tree

3 files changed

+46
-31
lines changed

3 files changed

+46
-31
lines changed

pandas/core/dtypes/dtypes.py

+20
Original file line numberDiff line numberDiff line change
@@ -1417,3 +1417,23 @@ def construct_array_type(cls) -> type_t[BaseMaskedArray]:
14171417
type
14181418
"""
14191419
raise NotImplementedError
1420+
1421+
@classmethod
def from_numpy_dtype(cls, dtype: np.dtype) -> BaseMaskedDtype:
    """
    Return the masked (nullable) dtype that mirrors a numpy dtype.

    Parameters
    ----------
    dtype : np.dtype
        The numpy dtype to translate. Dispatch is on ``dtype.kind``.

    Returns
    -------
    BaseMaskedDtype
        ``BooleanDtype()`` for kind ``"b"``; otherwise the entry stored
        under ``dtype.name`` in the integer (kinds ``"i"``/``"u"``) or
        floating (kind ``"f"``) lookup table.

    Raises
    ------
    NotImplementedError
        If ``dtype.kind`` is not one of boolean, integer or float.
    """
    # The array modules are imported inside each branch rather than at
    # module level -- presumably to avoid a circular import; confirm before
    # hoisting them.
    kind = dtype.kind

    if kind == "b":
        from pandas.core.arrays.boolean import BooleanDtype

        return BooleanDtype()

    if kind in ("i", "u"):
        from pandas.core.arrays.integer import INT_STR_TO_DTYPE

        # NOTE(review): a name missing from the table surfaces as the
        # table's own lookup error, not NotImplementedError.
        return INT_STR_TO_DTYPE[dtype.name]

    if kind == "f":
        from pandas.core.arrays.floating import FLOAT_STR_TO_DTYPE

        # NOTE(review): same caveat as the integer table above.
        return FLOAT_STR_TO_DTYPE[dtype.name]

    # Every other kind has no masked counterpart handled here.
    raise NotImplementedError(dtype)

pandas/core/groupby/ops.py

+19-30
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
Iterator,
1717
Sequence,
1818
final,
19-
overload,
2019
)
2120

2221
import numpy as np
@@ -57,7 +56,6 @@
5756
is_timedelta64_dtype,
5857
needs_i8_conversion,
5958
)
60-
from pandas.core.dtypes.dtypes import ExtensionDtype
6159
from pandas.core.dtypes.missing import (
6260
isna,
6361
maybe_fill,
@@ -70,14 +68,8 @@
7068
TimedeltaArray,
7169
)
7270
from pandas.core.arrays.boolean import BooleanDtype
73-
from pandas.core.arrays.floating import (
74-
Float64Dtype,
75-
FloatingDtype,
76-
)
77-
from pandas.core.arrays.integer import (
78-
Int64Dtype,
79-
IntegerDtype,
80-
)
71+
from pandas.core.arrays.floating import FloatingDtype
72+
from pandas.core.arrays.integer import IntegerDtype
8173
from pandas.core.arrays.masked import (
8274
BaseMaskedArray,
8375
BaseMaskedDtype,
@@ -277,41 +269,27 @@ def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
277269
out_dtype = "object"
278270
return np.dtype(out_dtype)
279271

280-
@overload
281272
def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
282-
... # pragma: no cover
283-
284-
@overload
285-
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype:
286-
... # pragma: no cover
287-
288-
# TODO: general case implementation overridable by EAs.
289-
def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
290273
"""
291274
Get the desired dtype of a result based on the
292275
input dtype and how it was computed.
293276
294277
Parameters
295278
----------
296-
dtype : np.dtype or ExtensionDtype
297-
Input dtype.
279+
dtype : np.dtype
298280
299281
Returns
300282
-------
301-
np.dtype or ExtensionDtype
283+
np.dtype
302284
The desired dtype of the result.
303285
"""
304286
how = self.how
305287

306288
if how in ["add", "cumsum", "sum", "prod"]:
307289
if dtype == np.dtype(bool):
308290
return np.dtype(np.int64)
309-
elif isinstance(dtype, (BooleanDtype, IntegerDtype)):
310-
return Int64Dtype()
311291
elif how in ["mean", "median", "var"]:
312-
if isinstance(dtype, (BooleanDtype, IntegerDtype)):
313-
return Float64Dtype()
314-
elif is_float_dtype(dtype) or is_complex_dtype(dtype):
292+
if is_float_dtype(dtype) or is_complex_dtype(dtype):
315293
return dtype
316294
elif is_numeric_dtype(dtype):
317295
return np.dtype(np.float64)
@@ -390,8 +368,18 @@ def _reconstruct_ea_result(
390368
Construct an ExtensionArray result from an ndarray result.
391369
"""
392370

393-
if isinstance(values.dtype, (BaseMaskedDtype, StringDtype)):
394-
dtype = self._get_result_dtype(values.dtype)
371+
if isinstance(values.dtype, StringDtype):
372+
dtype = values.dtype
373+
cls = dtype.construct_array_type()
374+
return cls._from_sequence(res_values, dtype=dtype)
375+
376+
elif isinstance(values.dtype, BaseMaskedDtype):
377+
new_dtype = self._get_result_dtype(values.dtype.numpy_dtype)
378+
# error: Incompatible types in assignment (expression has type
379+
# "BaseMaskedDtype", variable has type "StringDtype")
380+
dtype = BaseMaskedDtype.from_numpy_dtype( # type: ignore[assignment]
381+
new_dtype
382+
)
395383
cls = dtype.construct_array_type()
396384
return cls._from_sequence(res_values, dtype=dtype)
397385

@@ -433,7 +421,8 @@ def _masked_ea_wrap_cython_operation(
433421
**kwargs,
434422
)
435423

436-
dtype = self._get_result_dtype(orig_values.dtype)
424+
new_dtype = self._get_result_dtype(orig_values.dtype.numpy_dtype)
425+
dtype = BaseMaskedDtype.from_numpy_dtype(new_dtype)
437426
# TODO: avoid cast as res_values *should* already have the right
438427
# dtype; last attempt ran into trouble on 32bit linux build
439428
res_values = res_values.astype(dtype.type, copy=False)

pandas/tests/groupby/aggregate/test_cython.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import numpy as np
66
import pytest
77

8-
from pandas.core.dtypes.common import is_float_dtype
8+
from pandas.core.dtypes.common import (
9+
is_float_dtype,
10+
is_integer_dtype,
11+
)
912

1013
import pandas as pd
1114
from pandas import (
@@ -369,6 +372,9 @@ def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na):
369372
# preserve dtype for int and float; use Int64 for everything else (bool)
370373
if is_float_dtype(data.dtype):
371374
expected_dtype = data.dtype
375+
elif is_integer_dtype(data.dtype):
376+
# match the numpy dtype we'd get with the non-nullable analogue
377+
expected_dtype = data.dtype
372378
else:
373379
expected_dtype = pd.Int64Dtype()
374380
elif action == "always_float":

0 commit comments

Comments
 (0)