Commit 4dbadae

Authored Aug 14, 2024
Avoid explicit np.nan, np.inf (#383)
* Handle dtypes.NA properly for datetime/timedelta
* Add Aggregation.preserves_dtype
* Fix ffill, bfill
1 parent f0ce343 commit 4dbadae
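
Background for the change (an illustration, not part of the commit): np.nan and np.inf are float-only sentinels, so using them as fill values breaks datetime64/timedelta64 reductions; the dtype-aware sentinels in flox.xrdtypes resolve to NaT (or an integer min/max) instead. A minimal sketch of the failure mode, assuming a recent NumPy:

    import numpy as np

    times = np.array(["2024-08-14", "NaT"], dtype="datetime64[ns]")

    # np.isnan(times)                           # TypeError: isnan is not defined for datetime64
    # np.where(np.isnat(times), np.nan, times)  # TypeError: float64 and datetime64 cannot be promoted
    np.where(np.isnat(times), np.datetime64("NaT"), times)  # works: NaT is the dtype-aware missing value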

File tree

4 files changed: +104 -80 lines changed

flox/aggregate_flox.py (+6 -3)

@@ -2,6 +2,7 @@
 
 import numpy as np
 
+from . import xrdtypes as dtypes
 from .xrutils import is_scalar, isnull, notnull
 
 
@@ -98,7 +99,7 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
     # partition the complex array in-place
     labels_broadcast = np.broadcast_to(group_idx, array.shape)
     with np.errstate(invalid="ignore"):
-        cmplx = labels_broadcast + 1j * array
+        cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
     cmplx.partition(kth=kth, axis=-1)
     if is_scalar_q:
         a_ = cmplx.imag
@@ -158,6 +159,8 @@ def _np_grouped_op(
 
 
 def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
+    if fillna in [dtypes.INF, dtypes.NINF]:
+        fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna)
     result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
     # np.nanmax([np.nan, np.nan]) = np.nan
     # To recover this behaviour, we need to search for the fillna value
@@ -175,9 +178,9 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
 prod = partial(_np_grouped_op, op=np.multiply.reduceat)
 nanprod = partial(_nan_grouped_op, func=prod, fillna=1)
 max = partial(_np_grouped_op, op=np.maximum.reduceat)
-nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
+nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
 min = partial(_np_grouped_op, op=np.minimum.reduceat)
-nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
+nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
 quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
 nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
 median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))

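Why the array.view(int) guard above (a sketch, not part of the commit): complex arithmetic is undefined for datetime64/timedelta64, but viewing the values as integers keeps their ordering, so the labels + 1j * values partition trick still works. Assuming a platform whose default integer is 8 bytes:

    import numpy as np

    times = np.array(["2024-01-02", "2024-01-01", "2024-01-03"], dtype="datetime64[ns]")
    labels = np.array([0, 0, 0])

    # labels + 1j * times      # TypeError: complex multiply is not defined for datetime64
    ints = times.view(int)     # reinterpret the int64 nanoseconds; ordering is preserved
    cmplx = labels + 1j * ints # the partition-based grouped quantile can now proceed
    assert (np.argsort(cmplx.imag) == np.argsort(times)).all()
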
flox/aggregations.py (+37 -72)

@@ -115,60 +115,6 @@ def generic_aggregate(
     return result
 
 
-def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
-    if dtype is None:
-        dtype = array_dtype
-    if dtype is np.floating:
-        # mean, std, var always result in floating
-        # but we preserve the array's dtype if it is floating
-        if array_dtype.kind in "fcmM":
-            dtype = array_dtype
-        else:
-            dtype = np.dtype("float64")
-    elif not isinstance(dtype, np.dtype):
-        dtype = np.dtype(dtype)
-    if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
-        dtype = np.result_type(dtype, fill_value)
-    return dtype
-
-
-def _maybe_promote_int(dtype) -> np.dtype:
-    # https://numpy.org/doc/stable/reference/generated/numpy.prod.html
-    # The dtype of a is used by default unless a has an integer dtype of less precision
-    # than the default platform integer.
-    if not isinstance(dtype, np.dtype):
-        dtype = np.dtype(dtype)
-    if dtype.kind == "i":
-        dtype = np.result_type(dtype, np.intp)
-    elif dtype.kind == "u":
-        dtype = np.result_type(dtype, np.uintp)
-    return dtype
-
-
-def _get_fill_value(dtype, fill_value):
-    """Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
-    if fill_value in [None, dtypes.NA] and dtype.kind in "US":
-        return ""
-    if fill_value == dtypes.INF or fill_value is None:
-        return dtypes.get_pos_infinity(dtype, max_for_int=True)
-    if fill_value == dtypes.NINF:
-        return dtypes.get_neg_infinity(dtype, min_for_int=True)
-    if fill_value == dtypes.NA:
-        if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
-            return np.nan
-        # This is madness, but npg checks that fill_value is compatible
-        # with array dtype even if the fill_value is never used.
-        elif (
-            np.issubdtype(dtype, np.integer)
-            or np.issubdtype(dtype, np.timedelta64)
-            or np.issubdtype(dtype, np.datetime64)
-        ):
-            return dtypes.get_neg_infinity(dtype, min_for_int=True)
-        else:
-            return None
-    return fill_value
-
-
 def _atleast_1d(inp, min_length: int = 1):
     if xrutils.is_scalar(inp):
         inp = (inp,) * min_length
@@ -210,6 +156,7 @@ def __init__(
         final_dtype: DTypeLike | None = None,
         reduction_type: Literal["reduce", "argreduce"] = "reduce",
        new_dims_func: Callable | None = None,
+        preserves_dtype: bool = False,
     ):
         """
         Blueprint for computing grouped aggregations.
@@ -256,6 +203,8 @@ def __init__(
             Function that receives finalize_kwargs and returns a tupleof sizes of any new dimensions
             added by the reduction. For e.g. quantile for q=(0.5, 0.85) adds a new dimension of size 2,
             so returns (2,)
+        preserves_dtype: bool,
+            Whether a function preserves the dtype on return E.g. min, max, first, last, mode
         """
         self.name = name
         # preprocess before blockwise
@@ -292,6 +241,7 @@ def __init__(
         self.new_dims_func: Callable = (
             returns_empty_tuple if new_dims_func is None else new_dims_func
         )
+        self.preserves_dtype = preserves_dtype
 
     @cached_property
     def new_dims(self) -> tuple[Dim]:
@@ -434,10 +384,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
 )
 
 
-min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
-nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
-max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
-nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
+min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, preserves_dtype=True)
+nanmin = Aggregation(
+    "nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA, preserves_dtype=True
+)
+max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
+nanmax = Aggregation(
+    "nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA, preserves_dtype=True
+)
 
 
 def argreduce_preprocess(array, axis):
@@ -525,10 +479,14 @@ def _pick_second(*x):
     final_dtype=np.intp,
 )
 
-first = Aggregation("first", chunk=None, combine=None, fill_value=None)
-last = Aggregation("last", chunk=None, combine=None, fill_value=None)
-nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA)
-nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA)
+first = Aggregation("first", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
+last = Aggregation("last", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
+nanfirst = Aggregation(
+    "nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA, preserves_dtype=True
+)
+nanlast = Aggregation(
+    "nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA, preserves_dtype=True
+)
 
 all_ = Aggregation(
     "all",
@@ -579,8 +537,12 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
     final_dtype=np.floating,
     new_dims_func=quantile_new_dims_func,
 )
-mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
-nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
+mode = Aggregation(
+    name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
+)
+nanmode = Aggregation(
+    name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
+)
 
 
 @dataclass
@@ -634,7 +596,7 @@ def last(self) -> AlignedArrays:
             # TODO: automate?
             engine="flox",
             dtype=self.array.dtype,
-            fill_value=_get_fill_value(self.array.dtype, dtypes.NA),
+            fill_value=dtypes._get_fill_value(self.array.dtype, dtypes.NA),
             expected_groups=None,
         )
         return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
@@ -729,6 +691,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
     binary_op=None,
     reduction="nanlast",
     scan="ffill",
+    # Important: this must be NaN otherwise, ffill does not work.
     identity=np.nan,
     mode="concat_then_scan",
 )
@@ -737,6 +700,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
     binary_op=None,
     reduction="nanlast",
     scan="ffill",
+    # Important: this must be NaN otherwise, bfill does not work.
     identity=np.nan,
     mode="concat_then_scan",
     preprocess=reverse,
@@ -815,17 +779,18 @@ def _initialize_aggregation(
     dtype_: np.dtype | None = (
         np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
     )
-
-    final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
-    if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]:
-        final_dtype = _maybe_promote_int(final_dtype)
+    final_dtype = dtypes._normalize_dtype(
+        dtype_ or agg.dtype_init["final"], array_dtype, fill_value
+    )
+    if not agg.preserves_dtype:
+        final_dtype = dtypes._maybe_promote_int(final_dtype)
     agg.dtype = {
         "user": dtype,  # Save to automatically choose an engine
         "final": final_dtype,
         "numpy": (final_dtype,),
         "intermediate": tuple(
             (
-                _normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
+                dtypes._normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
                 if int_dtype is None
                 else np.dtype(int_dtype)
             )
@@ -838,10 +803,10 @@ def _initialize_aggregation(
     # Replace sentinel fill values according to dtype
     agg.fill_value["user"] = fill_value
     agg.fill_value["intermediate"] = tuple(
-        _get_fill_value(dt, fv)
+        dtypes._get_fill_value(dt, fv)
         for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
     )
-    agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func])
+    agg.fill_value[func] = dtypes._get_fill_value(agg.dtype["final"], agg.fill_value[func])
 
     fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
     if _is_arg_reduction(agg):

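Intended effect of the new preserves_dtype flag (a hedged sketch, not part of the commit): _initialize_aggregation previously skipped integer promotion for a hard-coded list of names; it now skips it for any aggregation that declares it returns values of the input dtype. Roughly:

    import numpy as np
    import flox

    array = np.arange(6, dtype=np.int16)
    by = np.array([0, 0, 1, 1, 2, 2])

    summed, _ = flox.groupby_reduce(array, by, func="sum")
    minned, _ = flox.groupby_reduce(array, by, func="min")

    summed.dtype  # promoted to the platform integer, mirroring np.sum
    minned.dtype  # stays int16 -- "min" is constructed with preserves_dtype=True
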
flox/xrdtypes.py (+55)

@@ -1,6 +1,7 @@
 import functools
 
 import numpy as np
+from numpy.typing import DTypeLike
 
 from . import xrutils as utils
 
@@ -147,3 +148,57 @@ def get_neg_infinity(dtype, min_for_int=False):
 def is_datetime_like(dtype):
     """Check if a dtype is a subclass of the numpy datetime types"""
     return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
+
+
+def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
+    if dtype is None:
+        dtype = array_dtype
+    if dtype is np.floating:
+        # mean, std, var always result in floating
+        # but we preserve the array's dtype if it is floating
+        if array_dtype.kind in "fcmM":
+            dtype = array_dtype
+        else:
+            dtype = np.dtype("float64")
+    elif not isinstance(dtype, np.dtype):
+        dtype = np.dtype(dtype)
+    if fill_value not in [None, INF, NINF, NA]:
+        dtype = np.result_type(dtype, fill_value)
+    return dtype
+
+
+def _maybe_promote_int(dtype) -> np.dtype:
+    # https://numpy.org/doc/stable/reference/generated/numpy.prod.html
+    # The dtype of a is used by default unless a has an integer dtype of less precision
+    # than the default platform integer.
+    if not isinstance(dtype, np.dtype):
+        dtype = np.dtype(dtype)
+    if dtype.kind == "i":
+        dtype = np.result_type(dtype, np.intp)
+    elif dtype.kind == "u":
+        dtype = np.result_type(dtype, np.uintp)
+    return dtype
+
+
+def _get_fill_value(dtype, fill_value):
+    """Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
+    if fill_value in [None, NA] and dtype.kind in "US":
+        return ""
+    if fill_value == INF or fill_value is None:
+        return get_pos_infinity(dtype, max_for_int=True)
+    if fill_value == NINF:
+        return get_neg_infinity(dtype, min_for_int=True)
+    if fill_value == NA:
+        if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
+            return np.nan
+        # This is madness, but npg checks that fill_value is compatible
+        # with array dtype even if the fill_value is never used.
+        elif np.issubdtype(dtype, np.integer):
+            return get_neg_infinity(dtype, min_for_int=True)
+        elif np.issubdtype(dtype, np.timedelta64):
+            return np.timedelta64("NaT")
+        elif np.issubdtype(dtype, np.datetime64):
+            return np.datetime64("NaT")
+        else:
+            return None
+    return fill_value

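Illustration of the relocated helpers (grounded in the functions added above; expected results noted as comments):

    import numpy as np
    from flox import xrdtypes as dtypes

    dtypes._get_fill_value(np.dtype("float64"), dtypes.NA)          # nan
    dtypes._get_fill_value(np.dtype("datetime64[ns]"), dtypes.NA)   # np.datetime64("NaT")
    dtypes._get_fill_value(np.dtype("timedelta64[ns]"), dtypes.NA)  # np.timedelta64("NaT")
    dtypes._get_fill_value(np.dtype("int32"), dtypes.NA)            # dtype minimum, via get_neg_infinity(..., min_for_int=True)
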
tests/test_core.py (+6 -5)

@@ -13,8 +13,9 @@
 from numpy_groupies.aggregate_numpy import aggregate
 
 import flox
+from flox import xrdtypes as dtypes
 from flox import xrutils
-from flox.aggregations import Aggregation, _initialize_aggregation, _maybe_promote_int
+from flox.aggregations import Aggregation, _initialize_aggregation
 from flox.core import (
     HAS_NUMBAGG,
     _choose_engine,
@@ -161,7 +162,7 @@ def test_groupby_reduce(
     if func == "mean" or func == "nanmean":
         expected_result = np.array(expected, dtype=np.float64)
     elif func == "sum":
-        expected_result = np.array(expected, dtype=_maybe_promote_int(array.dtype))
+        expected_result = np.array(expected, dtype=dtypes._maybe_promote_int(array.dtype))
     elif func == "count":
         expected_result = np.array(expected, dtype=np.intp)
 
@@ -389,7 +390,7 @@ def test_groupby_reduce_preserves_dtype(dtype, func):
     array = np.ones((2, 12), dtype=dtype)
     by = np.array([labels] * 2)
     result, _ = groupby_reduce(from_array(array, chunks=(-1, 4)), by, func=func)
-    expect_dtype = _maybe_promote_int(array.dtype)
+    expect_dtype = dtypes._maybe_promote_int(array.dtype)
     assert result.dtype == expect_dtype
 
 
@@ -1054,7 +1055,7 @@ def test_dtype_preservation(dtype, func, engine):
         # https://github.com/numbagg/numbagg/issues/121
        pytest.skip()
     if func == "sum":
-        expected = _maybe_promote_int(dtype)
+        expected = dtypes._maybe_promote_int(dtype)
     elif func == "mean" and "int" in dtype:
         expected = np.float64
     else:
@@ -1085,7 +1086,7 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
     actual, actual_groups = groupby_reduce(array, labels, func="sum", method=method)
     assert_equal(actual_groups, np.arange(6, dtype=labels.dtype))
 
-    expect_dtype = _maybe_promote_int(dtype)
+    expect_dtype = dtypes._maybe_promote_int(dtype)
     assert_equal(actual, np.array([0, 4, 24, 6, 12, 20], dtype=expect_dtype))
 

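The tests now import _maybe_promote_int from flox.xrdtypes instead of flox.aggregations. For reference (an illustration, not part of the commit), it widens small integer dtypes the way np.sum/np.prod do and leaves everything else alone:

    import numpy as np
    from flox import xrdtypes as dtypes

    dtypes._maybe_promote_int(np.dtype("int16"))    # int64 on 64-bit platforms (np.result_type(int16, intp))
    dtypes._maybe_promote_int(np.dtype("uint8"))    # uint64 on 64-bit platforms
    dtypes._maybe_promote_int(np.dtype("float32"))  # float32 -- only integer kinds are promoted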