46
46
)
47
47
from .cache import memoize
48
48
from .xrutils import (
49
+ _contains_cftime_datetimes ,
50
+ _to_pytimedelta ,
51
+ datetime_to_numeric ,
49
52
is_chunked_array ,
50
53
is_duck_array ,
51
54
is_duck_cubed_array ,
@@ -172,6 +175,17 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
172
175
return func in ["nanfirst" , "nanlast" , "first" , "last" ]
173
176
174
177
178
+ def _is_bool_supported_reduction (func : T_Agg ) -> bool :
179
+ if isinstance (func , Aggregation ):
180
+ func = func .name
181
+ return (
182
+ func in ["all" , "any" ]
183
+ # TODO: enable in npg
184
+ # or _is_first_last_reduction(func)
185
+ # or _is_minmax_reduction(func)
186
+ )
187
+
188
+
175
189
def _get_expected_groups (by : T_By , sort : bool ) -> T_ExpectIndex :
176
190
if is_duck_dask_array (by ):
177
191
raise ValueError ("Please provide expected_groups if not grouping by a numpy array." )
@@ -2432,7 +2446,7 @@ def groupby_reduce(
2432
2446
array .dtype ,
2433
2447
)
2434
2448
2435
- is_bool_array = np .issubdtype (array .dtype , bool )
2449
+ is_bool_array = np .issubdtype (array .dtype , bool ) and not _is_bool_supported_reduction ( func )
2436
2450
array = array .astype (np .int_ ) if is_bool_array else array
2437
2451
2438
2452
isbins = _atleast_1d (isbin , nby )
@@ -2482,7 +2496,8 @@ def groupby_reduce(
2482
2496
has_dask = is_duck_dask_array (array ) or is_duck_dask_array (by_ )
2483
2497
has_cubed = is_duck_cubed_array (array ) or is_duck_cubed_array (by_ )
2484
2498
2485
- if _is_first_last_reduction (func ):
2499
+ is_first_last = _is_first_last_reduction (func )
2500
+ if is_first_last :
2486
2501
if has_dask and nax != 1 :
2487
2502
raise ValueError (
2488
2503
"For dask arrays: first, last, nanfirst, nanlast reductions are "
@@ -2495,6 +2510,22 @@ def groupby_reduce(
2495
2510
"along a single axis or when reducing across all dimensions of `by`."
2496
2511
)
2497
2512
2513
+ is_npdatetime = array .dtype .kind in "Mm"
2514
+ is_cftime = _contains_cftime_datetimes (array )
2515
+ requires_numeric = (
2516
+ (func not in ["count" , "any" , "all" ] and not is_first_last )
2517
+ # Flox's count works with non-numeric and its faster than converting.
2518
+ or (func == "count" and engine != "flox" )
2519
+ or (is_first_last and is_cftime )
2520
+ )
2521
+ if requires_numeric :
2522
+ if is_npdatetime :
2523
+ datetime_dtype = array .dtype
2524
+ array = array .view (np .int64 )
2525
+ elif is_cftime :
2526
+ offset = array .min ()
2527
+ array = datetime_to_numeric (array , offset , datetime_unit = "us" )
2528
+
2498
2529
if nax == 1 and by_ .ndim > 1 and expected_ is None :
2499
2530
# When we reduce along all axes, we are guaranteed to see all
2500
2531
# groups in the final combine stage, so everything works.
@@ -2680,6 +2711,14 @@ def groupby_reduce(
2680
2711
2681
2712
if is_bool_array and (_is_minmax_reduction (func ) or _is_first_last_reduction (func )):
2682
2713
result = result .astype (bool )
2714
+
2715
+ # Output of count has an int dtype.
2716
+ if requires_numeric and func != "count" :
2717
+ if is_npdatetime :
2718
+ result = result .astype (datetime_dtype )
2719
+ elif is_cftime :
2720
+ result = _to_pytimedelta (result , unit = "us" ) + offset
2721
+
2683
2722
return (result , * groups )
2684
2723
2685
2724
@@ -2820,6 +2859,12 @@ def groupby_scan(
2820
2859
(by_ ,) = bys
2821
2860
has_dask = is_duck_dask_array (array ) or is_duck_dask_array (by_ )
2822
2861
2862
+ if array .dtype .kind in "Mm" :
2863
+ cast_to = array .dtype
2864
+ array = array .view (np .int64 )
2865
+ else :
2866
+ cast_to = None
2867
+
2823
2868
# TODO: move to aggregate_npg.py
2824
2869
if agg .name in ["cumsum" , "nancumsum" ] and array .dtype .kind in ["i" , "u" ]:
2825
2870
# https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
@@ -2835,7 +2880,10 @@ def groupby_scan(
2835
2880
(single_axis ,) = axis_ # type: ignore[misc]
2836
2881
# avoid some roundoff error when we can.
2837
2882
if by_ .shape [- 1 ] == 1 or by_ .shape == grp_shape :
2838
- return array .astype (agg .dtype )
2883
+ array = array .astype (agg .dtype )
2884
+ if cast_to is not None :
2885
+ array = array .astype (cast_to )
2886
+ return array
2839
2887
2840
2888
# Made a design choice here to have `preprocess` handle both array and group_idx
2841
2889
# Example: for reversing, we need to reverse the whole array, not just reverse
@@ -2854,6 +2902,9 @@ def groupby_scan(
2854
2902
out = AlignedArrays (array = result , group_idx = by_ )
2855
2903
if agg .finalize :
2856
2904
out = agg .finalize (out )
2905
+
2906
+ if cast_to is not None :
2907
+ return out .array .astype (cast_to )
2857
2908
return out .array
2858
2909
2859
2910
0 commit comments