Skip to content

Commit df81a8d

Browse files
authored
More support for datetime, timedelta (#412)
* More support for datetime, timedelta (closes #403)
* Rework property tests
* Add cftime property tests
* Add cftime unit test
* Smarter bool conversion
* typing
* more typing
* cubed bugfix
* xfail one more
* xfail nanprod too
1 parent 0344a28 commit df81a8d

8 files changed

+217
-92
lines changed

flox/aggregate_numbagg.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
"nanmean": {np.int_: np.float64},
3131
"nanvar": {np.int_: np.float64},
3232
"nanstd": {np.int_: np.float64},
33+
"nanfirst": {np.datetime64: np.int64, np.timedelta64: np.int64},
34+
"nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64},
3335
}
3436

3537

@@ -51,7 +53,7 @@ def _numbagg_wrapper(
5153
if cast_to:
5254
for from_, to_ in cast_to.items():
5355
if np.issubdtype(array.dtype, from_):
54-
array = array.astype(to_)
56+
array = array.astype(to_, copy=False)
5557

5658
func_ = getattr(numbagg.grouped, f"group_{func}")
5759

flox/core.py

+54-3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
)
4646
from .cache import memoize
4747
from .xrutils import (
48+
_contains_cftime_datetimes,
49+
_to_pytimedelta,
50+
datetime_to_numeric,
4851
is_chunked_array,
4952
is_duck_array,
5053
is_duck_cubed_array,
@@ -171,6 +174,17 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
171174
return func in ["nanfirst", "nanlast", "first", "last"]
172175

173176

177+
def _is_bool_supported_reduction(func: T_Agg) -> bool:
    """Return True when ``func`` is one of the reductions listed as bool-capable.

    ``func`` may be a reduction name or an ``Aggregation`` instance; in the
    latter case its ``name`` attribute is what gets checked. Only ``"all"``
    and ``"any"`` currently qualify.
    """
    name = func.name if isinstance(func, Aggregation) else func
    # TODO: enable in npg
    #   or _is_first_last_reduction(name)
    #   or _is_minmax_reduction(name)
    return name in ("all", "any")
186+
187+
174188
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
175189
if is_duck_dask_array(by):
176190
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -2422,7 +2436,7 @@ def groupby_reduce(
24222436
array.dtype,
24232437
)
24242438

2425-
is_bool_array = np.issubdtype(array.dtype, bool)
2439+
is_bool_array = np.issubdtype(array.dtype, bool) and not _is_bool_supported_reduction(func)
24262440
array = array.astype(np.int_) if is_bool_array else array
24272441

24282442
isbins = _atleast_1d(isbin, nby)
@@ -2472,7 +2486,8 @@ def groupby_reduce(
24722486
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
24732487
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
24742488

2475-
if _is_first_last_reduction(func):
2489+
is_first_last = _is_first_last_reduction(func)
2490+
if is_first_last:
24762491
if has_dask and nax != 1:
24772492
raise ValueError(
24782493
"For dask arrays: first, last, nanfirst, nanlast reductions are "
@@ -2485,6 +2500,22 @@ def groupby_reduce(
24852500
"along a single axis or when reducing across all dimensions of `by`."
24862501
)
24872502

2503+
is_npdatetime = array.dtype.kind in "Mm"
2504+
is_cftime = _contains_cftime_datetimes(array)
2505+
requires_numeric = (
2506+
(func not in ["count", "any", "all"] and not is_first_last)
2507+
# Flox's count works with non-numeric and it's faster than converting.
2508+
or (func == "count" and engine != "flox")
2509+
or (is_first_last and is_cftime)
2510+
)
2511+
if requires_numeric:
2512+
if is_npdatetime:
2513+
datetime_dtype = array.dtype
2514+
array = array.view(np.int64)
2515+
elif is_cftime:
2516+
offset = array.min()
2517+
array = datetime_to_numeric(array, offset, datetime_unit="us")
2518+
24882519
if nax == 1 and by_.ndim > 1 and expected_ is None:
24892520
# When we reduce along all axes, we are guaranteed to see all
24902521
# groups in the final combine stage, so everything works.
@@ -2670,6 +2701,14 @@ def groupby_reduce(
26702701

26712702
if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
26722703
result = result.astype(bool)
2704+
2705+
# Output of count has an int dtype.
2706+
if requires_numeric and func != "count":
2707+
if is_npdatetime:
2708+
result = result.astype(datetime_dtype)
2709+
elif is_cftime:
2710+
result = _to_pytimedelta(result, unit="us") + offset
2711+
26732712
return (result, *groups)
26742713

26752714

@@ -2810,6 +2849,12 @@ def groupby_scan(
28102849
(by_,) = bys
28112850
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
28122851

2852+
if array.dtype.kind in "Mm":
2853+
cast_to = array.dtype
2854+
array = array.view(np.int64)
2855+
else:
2856+
cast_to = None
2857+
28132858
# TODO: move to aggregate_npg.py
28142859
if agg.name in ["cumsum", "nancumsum"] and array.dtype.kind in ["i", "u"]:
28152860
# https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
@@ -2825,7 +2870,10 @@ def groupby_scan(
28252870
(single_axis,) = axis_ # type: ignore[misc]
28262871
# avoid some roundoff error when we can.
28272872
if by_.shape[-1] == 1 or by_.shape == grp_shape:
2828-
return array.astype(agg.dtype)
2873+
array = array.astype(agg.dtype)
2874+
if cast_to is not None:
2875+
array = array.astype(cast_to)
2876+
return array
28292877

28302878
# Made a design choice here to have `preprocess` handle both array and group_idx
28312879
# Example: for reversing, we need to reverse the whole array, not just reverse
@@ -2844,6 +2892,9 @@ def groupby_scan(
28442892
out = AlignedArrays(array=result, group_idx=by_)
28452893
if agg.finalize:
28462894
out = agg.finalize(out)
2895+
2896+
if cast_to is not None:
2897+
return out.array.astype(cast_to)
28472898
return out.array
28482899

28492900

flox/xarray.py

-25
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pandas as pd
88
import xarray as xr
99
from packaging.version import Version
10-
from xarray.core.duck_array_ops import _datetime_nanmin
1110

1211
from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
1312
from .core import (
@@ -18,7 +17,6 @@
1817
)
1918
from .core import rechunk_for_blockwise as rechunk_array_for_blockwise
2019
from .core import rechunk_for_cohorts as rechunk_array_for_cohorts
21-
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric
2220

2321
if TYPE_CHECKING:
2422
from xarray.core.types import T_DataArray, T_Dataset
@@ -366,22 +364,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
366364
if "nan" not in func and func not in ["all", "any", "count"]:
367365
func = f"nan{func}"
368366

369-
# Flox's count works with non-numeric and its faster than converting.
370-
requires_numeric = func not in ["count", "any", "all"] or (
371-
func == "count" and kwargs["engine"] != "flox"
372-
)
373-
if requires_numeric:
374-
is_npdatetime = array.dtype.kind in "Mm"
375-
is_cftime = _contains_cftime_datetimes(array)
376-
if is_npdatetime:
377-
offset = _datetime_nanmin(array)
378-
# xarray always uses np.datetime64[ns] for np.datetime64 data
379-
dtype = "timedelta64[ns]"
380-
array = datetime_to_numeric(array, offset)
381-
elif is_cftime:
382-
offset = array.min()
383-
array = datetime_to_numeric(array, offset, datetime_unit="us")
384-
385367
result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
386368

387369
# Transpose the new quantile dimension to the end. This is ugly.
@@ -395,13 +377,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
395377
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
396378
result = np.moveaxis(result, 0, -1)
397379

398-
# Output of count has an int dtype.
399-
if requires_numeric and func != "count":
400-
if is_npdatetime:
401-
return result.astype(dtype) + offset
402-
elif is_cftime:
403-
return _to_pytimedelta(result, unit="us") + offset
404-
405380
return result
406381

407382
# These data variables do not have any of the core dimension,

flox/xrutils.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,6 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
213213
"""
214214
# TODO: make this function dask-compatible?
215215
# Set offset to minimum if not given
216-
from xarray.core.duck_array_ops import _datetime_nanmin
217-
218216
if offset is None:
219217
if array.dtype.kind in "Mm":
220218
offset = _datetime_nanmin(array)
@@ -345,6 +343,28 @@ def _contains_cftime_datetimes(array) -> bool:
345343
return False
346344

347345

346+
def _datetime_nanmin(array):
    """Compute a NaT-skipping minimum of a datetime-like array.

    The values are converted to float with NaT entries replaced by NaN,
    reduced with ``np.nanmin``, and the result is cast back to the dtype
    of the input.

    Caveats this function works around (as recorded when it was written):

    - In numpy < 1.18, min() on datetime64 incorrectly ignores NaT
    - numpy nanmin() does not work on datetime64
    - dask min() does not work on datetime64
    """
    from .xrdtypes import is_datetime_like

    original_dtype = array.dtype
    assert is_datetime_like(original_dtype)

    # Casting NaT to float does not yield NaN, so mask the missing entries explicitly.
    missing = pd.isnull(array)
    as_float = np.where(missing, np.nan, array.astype(float))
    minimum = np.nanmin(as_float)

    # A scalar result is wrapped in an array before the cast back.
    if isinstance(minimum, float):
        minimum = np.array(minimum)

    # Casting NaN back to the datetime-like dtype does yield NaT.
    return minimum.astype(original_dtype)
366+
367+
348368
def _select_along_axis(values, idx, axis):
349369
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
350370
sl = other_ind[:axis] + (idx,) + other_ind[axis:]

tests/strategies.py

+49-40
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,6 @@
1313

1414
Chunks = tuple[tuple[int, ...], ...]
1515

16-
17-
def supported_dtypes() -> st.SearchStrategy[np.dtype]:
18-
return (
19-
npst.integer_dtypes(endianness="=")
20-
| npst.unsigned_integer_dtypes(endianness="=")
21-
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
22-
| npst.complex_number_dtypes(endianness="=")
23-
| npst.datetime64_dtypes(endianness="=")
24-
| npst.timedelta64_dtypes(endianness="=")
25-
| npst.unicode_string_dtypes(endianness="=")
26-
)
27-
28-
29-
# TODO: stop excluding everything but U
30-
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
31-
by_dtype_st = supported_dtypes()
32-
33-
NON_NUMPY_FUNCS = [
34-
"first",
35-
"last",
36-
"nanfirst",
37-
"nanlast",
38-
"count",
39-
"any",
40-
"all",
41-
] + list(SCIPY_STATS_FUNCS)
42-
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
43-
44-
func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
45-
numeric_arrays = npst.arrays(
46-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
47-
)
48-
all_arrays = npst.arrays(
49-
elements={"allow_subnormal": False},
50-
shape=npst.array_shapes(),
51-
dtype=supported_dtypes(),
52-
)
53-
5416
calendars = st.sampled_from(
5517
[
5618
"standard",
@@ -89,7 +51,7 @@ def units(draw, *, calendar: str) -> str:
8951
def cftime_arrays(
9052
draw: st.DrawFn,
9153
*,
92-
shape: tuple[int, ...],
54+
shape: st.SearchStrategy[tuple[int, ...]] = npst.array_shapes(),
9355
calendars: st.SearchStrategy[str] = calendars,
9456
elements: dict[str, Any] | None = None,
9557
) -> np.ndarray[Any, Any]:
@@ -103,8 +65,55 @@ def cftime_arrays(
10365
return cftime.num2date(values, units=unit, calendar=cal)
10466

10567

68+
# Integer, unsigned integer and 32/64-bit float dtypes, native byte order only.
# TODO: add complex here not in supported_dtypes
numeric_dtypes = st.one_of(
    npst.integer_dtypes(endianness="="),
    npst.unsigned_integer_dtypes(endianness="="),
    npst.floating_dtypes(endianness="=", sizes=(32, 64)),
)

# Numeric dtypes plus booleans, datetime64 and timedelta64.
numeric_like_dtypes = st.one_of(
    npst.boolean_dtypes(),
    numeric_dtypes,
    npst.datetime64_dtypes(endianness="="),
    npst.timedelta64_dtypes(endianness="="),
)

# Everything above plus unicode strings and complex numbers.
supported_dtypes = st.one_of(
    numeric_like_dtypes,
    npst.unicode_string_dtypes(endianness="="),
    npst.complex_number_dtypes(endianness="="),
)
by_dtype_st = supported_dtypes

# Reductions excluded from ``func_st`` below.
NON_NUMPY_FUNCS = [
    "first",
    "last",
    "nanfirst",
    "nanlast",
    "count",
    "any",
    "all",
    *SCIPY_STATS_FUNCS,
]
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]

func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS + SKIPPED_FUNCS])

# Arrays of arbitrary shape; subnormal floats are never generated.
numeric_arrays = npst.arrays(
    dtype=numeric_dtypes,
    shape=npst.array_shapes(),
    elements={"allow_subnormal": False},
)
numeric_like_arrays = npst.arrays(
    dtype=numeric_like_dtypes,
    shape=npst.array_shapes(),
    elements={"allow_subnormal": False},
)
# Numeric-like numpy arrays or cftime arrays.
all_arrays = st.one_of(
    npst.arrays(
        dtype=numeric_like_dtypes,
        shape=npst.array_shapes(),
        elements={"allow_subnormal": False},
    ),
    cftime_arrays(),
)
113+
114+
106115
def by_arrays(
107-
shape: tuple[int, ...], *, elements: dict[str, Any] | None = None
116+
shape: st.SearchStrategy[tuple[int, ...]], *, elements: dict[str, Any] | None = None
108117
) -> st.SearchStrategy[np.ndarray[Any, Any]]:
109118
if elements is None:
110119
elements = {}

tests/test_core.py

+29
Original file line numberDiff line numberDiff line change
@@ -2007,3 +2007,32 @@ def test_blockwise_avoid_rechunk():
20072007
actual, groups = groupby_reduce(array, by, func="first")
20082008
assert_equal(groups, ["", "0", "1"])
20092009
assert_equal(actual, np.array([0, 0, 0], dtype=np.int64))
2010+
2011+
2012+
def test_datetime_minmax(engine):
    """Grouped ``nanmin``/``nanmax`` on datetime64 input pick the expected dates (GH403)."""
    array = np.array([np.datetime64(day) for day in ("2000-01-01", "2000-01-02", "2000-01-03")])
    by = np.array([0, 0, 1])

    # Group 0 holds the first two dates, group 1 only the last one.
    for func, picked in (("nanmin", [0, 2]), ("nanmax", [1, 2])):
        actual, _ = flox.groupby_reduce(array, by, func=func, engine=engine)
        expected = array[picked]
        assert_equal(expected, actual)
2023+
2024+
2025+
@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
def test_datetime_timedelta_first_last(engine, func):
    """first/last style reductions return the matching element for datetime and timedelta input."""
    import flox

    # "first"/"nanfirst" select the leading element, everything else the trailing one.
    idx = -1 if "first" not in func else 0

    datetimes = pd.date_range("2001-01-01", freq="d", periods=5).values
    timedeltas = datetimes - datetimes[0]
    # Every element belongs to the same single group.
    by = np.ones(datetimes.shape, dtype=int)

    for values in (datetimes, timedeltas):
        actual, _ = flox.groupby_reduce(values, by, func=func, engine=engine)
        assert_equal(actual, values[[idx]])

0 commit comments

Comments
 (0)