4646)
4747from .cache import memoize
4848from .xrutils import (
49+ _contains_cftime_datetimes ,
50+ _to_pytimedelta ,
51+ datetime_to_numeric ,
4952 is_chunked_array ,
5053 is_duck_array ,
5154 is_duck_cubed_array ,
@@ -172,6 +175,17 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
172175 return func in ["nanfirst" , "nanlast" , "first" , "last" ]
173176
174177
178+ def _is_bool_supported_reduction (func : T_Agg ) -> bool :
179+ if isinstance (func , Aggregation ):
180+ func = func .name
181+ return (
182+ func in ["all" , "any" ]
183+ # TODO: enable in npg
184+ # or _is_first_last_reduction(func)
185+ # or _is_minmax_reduction(func)
186+ )
187+
188+
175189def _get_expected_groups (by : T_By , sort : bool ) -> T_ExpectIndex :
176190 if is_duck_dask_array (by ):
177191 raise ValueError ("Please provide expected_groups if not grouping by a numpy array." )
@@ -2432,7 +2446,7 @@ def groupby_reduce(
24322446 array .dtype ,
24332447 )
24342448
2435- is_bool_array = np .issubdtype (array .dtype , bool )
2449+ is_bool_array = np .issubdtype (array .dtype , bool ) and not _is_bool_supported_reduction ( func )
24362450 array = array .astype (np .int_ ) if is_bool_array else array
24372451
24382452 isbins = _atleast_1d (isbin , nby )
@@ -2482,7 +2496,8 @@ def groupby_reduce(
24822496 has_dask = is_duck_dask_array (array ) or is_duck_dask_array (by_ )
24832497 has_cubed = is_duck_cubed_array (array ) or is_duck_cubed_array (by_ )
24842498
2485- if _is_first_last_reduction (func ):
2499+ is_first_last = _is_first_last_reduction (func )
2500+ if is_first_last :
24862501 if has_dask and nax != 1 :
24872502 raise ValueError (
24882503 "For dask arrays: first, last, nanfirst, nanlast reductions are "
@@ -2495,6 +2510,22 @@ def groupby_reduce(
24952510 "along a single axis or when reducing across all dimensions of `by`."
24962511 )
24972512
2513+ is_npdatetime = array .dtype .kind in "Mm"
2514+ is_cftime = _contains_cftime_datetimes (array )
2515+ requires_numeric = (
2516+ (func not in ["count" , "any" , "all" ] and not is_first_last )
2517+ # Flox's count works with non-numeric and its faster than converting.
2518+ or (func == "count" and engine != "flox" )
2519+ or (is_first_last and is_cftime )
2520+ )
2521+ if requires_numeric :
2522+ if is_npdatetime :
2523+ datetime_dtype = array .dtype
2524+ array = array .view (np .int64 )
2525+ elif is_cftime :
2526+ offset = array .min ()
2527+ array = datetime_to_numeric (array , offset , datetime_unit = "us" )
2528+
24982529 if nax == 1 and by_ .ndim > 1 and expected_ is None :
24992530 # When we reduce along all axes, we are guaranteed to see all
25002531 # groups in the final combine stage, so everything works.
@@ -2680,6 +2711,14 @@ def groupby_reduce(
26802711
26812712 if is_bool_array and (_is_minmax_reduction (func ) or _is_first_last_reduction (func )):
26822713 result = result .astype (bool )
2714+
2715+ # Output of count has an int dtype.
2716+ if requires_numeric and func != "count" :
2717+ if is_npdatetime :
2718+ result = result .astype (datetime_dtype )
2719+ elif is_cftime :
2720+ result = _to_pytimedelta (result , unit = "us" ) + offset
2721+
26832722 return (result , * groups )
26842723
26852724
@@ -2820,6 +2859,12 @@ def groupby_scan(
28202859 (by_ ,) = bys
28212860 has_dask = is_duck_dask_array (array ) or is_duck_dask_array (by_ )
28222861
2862+ if array .dtype .kind in "Mm" :
2863+ cast_to = array .dtype
2864+ array = array .view (np .int64 )
2865+ else :
2866+ cast_to = None
2867+
28232868 # TODO: move to aggregate_npg.py
28242869 if agg .name in ["cumsum" , "nancumsum" ] and array .dtype .kind in ["i" , "u" ]:
28252870 # https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
@@ -2835,7 +2880,10 @@ def groupby_scan(
28352880 (single_axis ,) = axis_ # type: ignore[misc]
28362881 # avoid some roundoff error when we can.
28372882 if by_ .shape [- 1 ] == 1 or by_ .shape == grp_shape :
2838- return array .astype (agg .dtype )
2883+ array = array .astype (agg .dtype )
2884+ if cast_to is not None :
2885+ array = array .astype (cast_to )
2886+ return array
28392887
28402888 # Made a design choice here to have `preprocess` handle both array and group_idx
28412889 # Example: for reversing, we need to reverse the whole array, not just reverse
@@ -2854,6 +2902,9 @@ def groupby_scan(
28542902 out = AlignedArrays (array = result , group_idx = by_ )
28552903 if agg .finalize :
28562904 out = agg .finalize (out )
2905+
2906+ if cast_to is not None :
2907+ return out .array .astype (cast_to )
28572908 return out .array
28582909
28592910
0 commit comments