Skip to content

Commit 16b0bac

Browse files
committed
Merge branch 'main' into topk
* main: More support for datetime, timedelta (#412)
2 parents 6aa923a + df81a8d commit 16b0bac

8 files changed

+218
-92
lines changed

flox/aggregate_numbagg.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
"nanmean": {np.int_: np.float64},
3131
"nanvar": {np.int_: np.float64},
3232
"nanstd": {np.int_: np.float64},
33+
"nanfirst": {np.datetime64: np.int64, np.timedelta64: np.int64},
34+
"nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64},
3335
}
3436

3537

@@ -51,7 +53,7 @@ def _numbagg_wrapper(
5153
if cast_to:
5254
for from_, to_ in cast_to.items():
5355
if np.issubdtype(array.dtype, from_):
54-
array = array.astype(to_)
56+
array = array.astype(to_, copy=False)
5557

5658
func_ = getattr(numbagg.grouped, f"group_{func}")
5759

flox/core.py

+54-3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
)
4747
from .cache import memoize
4848
from .xrutils import (
49+
_contains_cftime_datetimes,
50+
_to_pytimedelta,
51+
datetime_to_numeric,
4952
is_chunked_array,
5053
is_duck_array,
5154
is_duck_cubed_array,
@@ -172,6 +175,17 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
172175
return func in ["nanfirst", "nanlast", "first", "last"]
173176

174177

178+
def _is_bool_supported_reduction(func: T_Agg) -> bool:
179+
if isinstance(func, Aggregation):
180+
func = func.name
181+
return (
182+
func in ["all", "any"]
183+
# TODO: enable in npg
184+
# or _is_first_last_reduction(func)
185+
# or _is_minmax_reduction(func)
186+
)
187+
188+
175189
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
176190
if is_duck_dask_array(by):
177191
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -2432,7 +2446,7 @@ def groupby_reduce(
24322446
array.dtype,
24332447
)
24342448

2435-
is_bool_array = np.issubdtype(array.dtype, bool)
2449+
is_bool_array = np.issubdtype(array.dtype, bool) and not _is_bool_supported_reduction(func)
24362450
array = array.astype(np.int_) if is_bool_array else array
24372451

24382452
isbins = _atleast_1d(isbin, nby)
@@ -2482,7 +2496,8 @@ def groupby_reduce(
24822496
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
24832497
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
24842498

2485-
if _is_first_last_reduction(func):
2499+
is_first_last = _is_first_last_reduction(func)
2500+
if is_first_last:
24862501
if has_dask and nax != 1:
24872502
raise ValueError(
24882503
"For dask arrays: first, last, nanfirst, nanlast reductions are "
@@ -2495,6 +2510,22 @@ def groupby_reduce(
24952510
"along a single axis or when reducing across all dimensions of `by`."
24962511
)
24972512

2513+
is_npdatetime = array.dtype.kind in "Mm"
2514+
is_cftime = _contains_cftime_datetimes(array)
2515+
requires_numeric = (
2516+
(func not in ["count", "any", "all"] and not is_first_last)
2517+
# Flox's count works with non-numeric and its faster than converting.
2518+
or (func == "count" and engine != "flox")
2519+
or (is_first_last and is_cftime)
2520+
)
2521+
if requires_numeric:
2522+
if is_npdatetime:
2523+
datetime_dtype = array.dtype
2524+
array = array.view(np.int64)
2525+
elif is_cftime:
2526+
offset = array.min()
2527+
array = datetime_to_numeric(array, offset, datetime_unit="us")
2528+
24982529
if nax == 1 and by_.ndim > 1 and expected_ is None:
24992530
# When we reduce along all axes, we are guaranteed to see all
25002531
# groups in the final combine stage, so everything works.
@@ -2680,6 +2711,14 @@ def groupby_reduce(
26802711

26812712
if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
26822713
result = result.astype(bool)
2714+
2715+
# Output of count has an int dtype.
2716+
if requires_numeric and func != "count":
2717+
if is_npdatetime:
2718+
result = result.astype(datetime_dtype)
2719+
elif is_cftime:
2720+
result = _to_pytimedelta(result, unit="us") + offset
2721+
26832722
return (result, *groups)
26842723

26852724

@@ -2820,6 +2859,12 @@ def groupby_scan(
28202859
(by_,) = bys
28212860
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
28222861

2862+
if array.dtype.kind in "Mm":
2863+
cast_to = array.dtype
2864+
array = array.view(np.int64)
2865+
else:
2866+
cast_to = None
2867+
28232868
# TODO: move to aggregate_npg.py
28242869
if agg.name in ["cumsum", "nancumsum"] and array.dtype.kind in ["i", "u"]:
28252870
# https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
@@ -2835,7 +2880,10 @@ def groupby_scan(
28352880
(single_axis,) = axis_ # type: ignore[misc]
28362881
# avoid some roundoff error when we can.
28372882
if by_.shape[-1] == 1 or by_.shape == grp_shape:
2838-
return array.astype(agg.dtype)
2883+
array = array.astype(agg.dtype)
2884+
if cast_to is not None:
2885+
array = array.astype(cast_to)
2886+
return array
28392887

28402888
# Made a design choice here to have `preprocess` handle both array and group_idx
28412889
# Example: for reversing, we need to reverse the whole array, not just reverse
@@ -2854,6 +2902,9 @@ def groupby_scan(
28542902
out = AlignedArrays(array=result, group_idx=by_)
28552903
if agg.finalize:
28562904
out = agg.finalize(out)
2905+
2906+
if cast_to is not None:
2907+
return out.array.astype(cast_to)
28572908
return out.array
28582909

28592910

flox/xarray.py

-24
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pandas as pd
88
import xarray as xr
99
from packaging.version import Version
10-
from xarray.core.duck_array_ops import _datetime_nanmin
1110

1211
from .aggregations import (
1312
Aggregation,
@@ -24,7 +23,6 @@
2423
)
2524
from .core import rechunk_for_blockwise as rechunk_array_for_blockwise
2625
from .core import rechunk_for_cohorts as rechunk_array_for_cohorts
27-
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric
2826

2927
if TYPE_CHECKING:
3028
from xarray.core.types import T_DataArray, T_Dataset
@@ -372,22 +370,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
372370
if "nan" not in func and func not in ["all", "any", "count"]:
373371
func = f"nan{func}"
374372

375-
# Flox's count works with non-numeric and its faster than converting.
376-
requires_numeric = func not in ["count", "any", "all"] or (
377-
func == "count" and kwargs["engine"] != "flox"
378-
)
379-
if requires_numeric:
380-
is_npdatetime = array.dtype.kind in "Mm"
381-
is_cftime = _contains_cftime_datetimes(array)
382-
if is_npdatetime:
383-
offset = _datetime_nanmin(array)
384-
# xarray always uses np.datetime64[ns] for np.datetime64 data
385-
dtype = "timedelta64[ns]"
386-
array = datetime_to_numeric(array, offset)
387-
elif is_cftime:
388-
offset = array.min()
389-
array = datetime_to_numeric(array, offset, datetime_unit="us")
390-
391373
result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
392374

393375
# Transpose the new quantile or topk dimension to the end. This is ugly.
@@ -404,12 +386,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
404386
# This transpose is simply makes it easy to specify output_core_dims
405387
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
406388
result = np.moveaxis(result, 0, -1)
407-
# Output of count has an int dtype.
408-
if requires_numeric and func != "count":
409-
if is_npdatetime:
410-
return result.astype(dtype) + offset
411-
elif is_cftime:
412-
return _to_pytimedelta(result, unit="us") + offset
413389

414390
return result
415391

flox/xrutils.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,6 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
213213
"""
214214
# TODO: make this function dask-compatible?
215215
# Set offset to minimum if not given
216-
from xarray.core.duck_array_ops import _datetime_nanmin
217-
218216
if offset is None:
219217
if array.dtype.kind in "Mm":
220218
offset = _datetime_nanmin(array)
@@ -345,6 +343,28 @@ def _contains_cftime_datetimes(array) -> bool:
345343
return False
346344

347345

346+
def _datetime_nanmin(array):
347+
"""nanmin() function for datetime64.
348+
349+
Caveats that this function deals with:
350+
351+
- In numpy < 1.18, min() on datetime64 incorrectly ignores NaT
352+
- numpy nanmin() don't work on datetime64 (all versions at the moment of writing)
353+
- dask min() does not work on datetime64 (all versions at the moment of writing)
354+
"""
355+
from .xrdtypes import is_datetime_like
356+
357+
dtype = array.dtype
358+
assert is_datetime_like(dtype)
359+
# (NaT).astype(float) does not produce NaN...
360+
array = np.where(pd.isnull(array), np.nan, array.astype(float))
361+
array = np.nanmin(array)
362+
if isinstance(array, float):
363+
array = np.array(array)
364+
# ...but (NaN).astype("M8") does produce NaT
365+
return array.astype(dtype)
366+
367+
348368
def _select_along_axis(values, idx, axis):
349369
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
350370
sl = other_ind[:axis] + (idx,) + other_ind[axis:]

tests/strategies.py

+49-40
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,6 @@
1313

1414
Chunks = tuple[tuple[int, ...], ...]
1515

16-
17-
def supported_dtypes() -> st.SearchStrategy[np.dtype]:
18-
return (
19-
npst.integer_dtypes(endianness="=")
20-
| npst.unsigned_integer_dtypes(endianness="=")
21-
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
22-
| npst.complex_number_dtypes(endianness="=")
23-
| npst.datetime64_dtypes(endianness="=")
24-
| npst.timedelta64_dtypes(endianness="=")
25-
| npst.unicode_string_dtypes(endianness="=")
26-
)
27-
28-
29-
# TODO: stop excluding everything but U
30-
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
31-
by_dtype_st = supported_dtypes()
32-
33-
NON_NUMPY_FUNCS = [
34-
"first",
35-
"last",
36-
"nanfirst",
37-
"nanlast",
38-
"count",
39-
"any",
40-
"all",
41-
] + list(SCIPY_STATS_FUNCS)
42-
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
43-
44-
func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
45-
numeric_arrays = npst.arrays(
46-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
47-
)
48-
all_arrays = npst.arrays(
49-
elements={"allow_subnormal": False},
50-
shape=npst.array_shapes(),
51-
dtype=supported_dtypes(),
52-
)
53-
5416
calendars = st.sampled_from(
5517
[
5618
"standard",
@@ -89,7 +51,7 @@ def units(draw, *, calendar: str) -> str:
8951
def cftime_arrays(
9052
draw: st.DrawFn,
9153
*,
92-
shape: tuple[int, ...],
54+
shape: st.SearchStrategy[tuple[int, ...]] = npst.array_shapes(),
9355
calendars: st.SearchStrategy[str] = calendars,
9456
elements: dict[str, Any] | None = None,
9557
) -> np.ndarray[Any, Any]:
@@ -103,8 +65,55 @@ def cftime_arrays(
10365
return cftime.num2date(values, units=unit, calendar=cal)
10466

10567

68+
numeric_dtypes = (
69+
npst.integer_dtypes(endianness="=")
70+
| npst.unsigned_integer_dtypes(endianness="=")
71+
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
72+
# TODO: add complex here not in supported_dtypes
73+
)
74+
numeric_like_dtypes = (
75+
npst.boolean_dtypes()
76+
| numeric_dtypes
77+
| npst.datetime64_dtypes(endianness="=")
78+
| npst.timedelta64_dtypes(endianness="=")
79+
)
80+
supported_dtypes = (
81+
numeric_like_dtypes
82+
| npst.unicode_string_dtypes(endianness="=")
83+
| npst.complex_number_dtypes(endianness="=")
84+
)
85+
by_dtype_st = supported_dtypes
86+
87+
NON_NUMPY_FUNCS = [
88+
"first",
89+
"last",
90+
"nanfirst",
91+
"nanlast",
92+
"count",
93+
"any",
94+
"all",
95+
] + list(SCIPY_STATS_FUNCS)
96+
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
97+
98+
func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
99+
numeric_arrays = npst.arrays(
100+
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_dtypes
101+
)
102+
numeric_like_arrays = npst.arrays(
103+
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_like_dtypes
104+
)
105+
all_arrays = (
106+
npst.arrays(
107+
elements={"allow_subnormal": False},
108+
shape=npst.array_shapes(),
109+
dtype=numeric_like_dtypes,
110+
)
111+
| cftime_arrays()
112+
)
113+
114+
106115
def by_arrays(
107-
shape: tuple[int, ...], *, elements: dict[str, Any] | None = None
116+
shape: st.SearchStrategy[tuple[int, ...]], *, elements: dict[str, Any] | None = None
108117
) -> st.SearchStrategy[np.ndarray[Any, Any]]:
109118
if elements is None:
110119
elements = {}

tests/test_core.py

+29
Original file line numberDiff line numberDiff line change
@@ -2007,3 +2007,32 @@ def test_blockwise_avoid_rechunk():
20072007
actual, groups = groupby_reduce(array, by, func="first")
20082008
assert_equal(groups, ["", "0", "1"])
20092009
assert_equal(actual, np.array([0, 0, 0], dtype=np.int64))
2010+
2011+
2012+
def test_datetime_minmax(engine):
2013+
# GH403
2014+
array = np.array([np.datetime64("2000-01-01"), np.datetime64("2000-01-02"), np.datetime64("2000-01-03")])
2015+
by = np.array([0, 0, 1])
2016+
actual, _ = flox.groupby_reduce(array, by, func="nanmin", engine=engine)
2017+
expected = array[[0, 2]]
2018+
assert_equal(expected, actual)
2019+
2020+
expected = array[[1, 2]]
2021+
actual, _ = flox.groupby_reduce(array, by, func="nanmax", engine=engine)
2022+
assert_equal(expected, actual)
2023+
2024+
2025+
@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
2026+
def test_datetime_timedelta_first_last(engine, func):
2027+
import flox
2028+
2029+
idx = 0 if "first" in func else -1
2030+
2031+
dt = pd.date_range("2001-01-01", freq="d", periods=5).values
2032+
by = np.ones(dt.shape, dtype=int)
2033+
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
2034+
assert_equal(actual, dt[[idx]])
2035+
2036+
dt = dt - dt[0]
2037+
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
2038+
assert_equal(actual, dt[[idx]])

0 commit comments

Comments
 (0)