
Commit 8ee7488

Pint array: strip and reattach appropriate units
Closes #163
1 parent cd6eeb5

File tree

flox/aggregations.py
flox/core.py
tests/__init__.py
tests/test_core.py

4 files changed: +124 −12 lines

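Before the per-file diffs, one note: the core.py changes below import _strip_units and _reattach_units from a new flox.pint_compat module that is not included in this commit view. A minimal sketch of what such helpers could look like is given here; the names match the imports below, but the bodies are assumptions, not the actual module.

# Hypothetical sketch of flox/pint_compat.py -- the module is not part of this
# diff, so the real implementation may differ.
import pint


def _strip_units(*arrays):
    # Remember the units of any pint.Quantity inputs and return plain
    # magnitudes, followed by the list of units (None for unitless inputs).
    units = [getattr(a, "units", None) for a in arrays]
    stripped = [a.magnitude if isinstance(a, pint.Quantity) else a for a in arrays]
    return (*stripped, units)


def _reattach_units(*arrays, units):
    # Wrap each output back into a pint.Quantity where a unit was recorded.
    # An entry may also be a Quantity template (e.g. ``agg.units`` below), in
    # which case its ``.units`` attribute is used.
    return tuple(
        a if u is None else pint.Quantity(a, getattr(u, "units", u))
        for a, u in zip(arrays, units)
    )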
flox/aggregations.py

+62 −11
@@ -3,6 +3,7 @@
 import copy
 import warnings
 from functools import partial
+from typing import Callable
 
 import numpy as np
 import numpy_groupies as npg
@@ -114,6 +115,7 @@ def __init__(
         dtypes=None,
         final_dtype=None,
         reduction_type="reduce",
+        units_func: Callable | None = None,
     ):
         """
         Blueprint for computing grouped aggregations.
@@ -156,6 +158,8 @@ def __init__(
             per reduction in ``chunk`` as a tuple.
         final_dtype : DType, optional
             DType for output. By default, uses dtype of array being reduced.
+        units_func : Callable, optional
+            Function mapping the input units to the units of the output.
         """
         self.name = name
         # preprocess before blockwise
@@ -187,6 +191,8 @@ def __init__(
         # The following are set by _initialize_aggregation
         self.finalize_kwargs = {}
         self.min_count = None
+        self.units_func = units_func
+        self.units = None
 
     def _normalize_dtype_fill_value(self, value, name):
         value = _atleast_1d(value)
@@ -235,17 +241,44 @@ def __repr__(self):
     final_dtype=np.intp,
 )
 
+
+def identity(x):
+    return x
+
+
+def square(x):
+    return x**2
+
+
+def raise_units_error(x):
+    raise ValueError(
+        "Units are not supported for prod in general. "
+        "We can only attach units when there are "
+        "an equal number of members in each group. "
+        "Please strip units and then reattach them "
+        "to the output manually."
+    )
+
+
 # note that the fill values are the result of np.func([np.nan, np.nan])
 # final_fill_value is used for groups that don't exist. This is usually np.nan
-sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
-nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
-prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
+sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0, units_func=identity)
+nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0, units_func=identity)
+prod = Aggregation(
+    "prod",
+    chunk="prod",
+    combine="prod",
+    fill_value=1,
+    final_fill_value=1,
+    units_func=raise_units_error,
+)
 nanprod = Aggregation(
     "nanprod",
     chunk="nanprod",
     combine="prod",
     fill_value=1,
     final_fill_value=dtypes.NA,
+    units_func=raise_units_error,
 )
 
 
@@ -262,6 +295,7 @@ def _mean_finalize(sum_, count):
     fill_value=(0, 0),
     dtypes=(None, np.intp),
     final_dtype=np.floating,
+    units_func=identity,
 )
 nanmean = Aggregation(
     "nanmean",
@@ -271,6 +305,7 @@ def _mean_finalize(sum_, count):
     fill_value=(0, 0),
     dtypes=(None, np.intp),
     final_dtype=np.floating,
+    units_func=identity,
 )
 
 
@@ -296,6 +331,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
     final_fill_value=np.nan,
     dtypes=(None, None, np.intp),
     final_dtype=np.floating,
+    units_func=square,
 )
 nanvar = Aggregation(
     "nanvar",
@@ -306,6 +342,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
     final_fill_value=np.nan,
     dtypes=(None, None, np.intp),
     final_dtype=np.floating,
+    units_func=square,
 )
 std = Aggregation(
     "std",
@@ -316,6 +353,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
     final_fill_value=np.nan,
     dtypes=(None, None, np.intp),
     final_dtype=np.floating,
+    units_func=identity,
 )
 nanstd = Aggregation(
     "nanstd",
@@ -329,10 +367,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
 )
 
 
-min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
-nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
-max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
-nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
+min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, units_func=identity)
+nanmin = Aggregation(
+    "nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan, units_func=identity
+)
+max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, units_func=identity)
+nanmax = Aggregation(
+    "nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan, units_func=identity
+)
 
 
 def argreduce_preprocess(array, axis):
@@ -420,10 +462,14 @@ def _pick_second(*x):
     final_dtype=np.intp,
 )
 
-first = Aggregation("first", chunk=None, combine=None, fill_value=0)
-last = Aggregation("last", chunk=None, combine=None, fill_value=0)
-nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
-nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)
+first = Aggregation("first", chunk=None, combine=None, fill_value=0, units_func=identity)
+last = Aggregation("last", chunk=None, combine=None, fill_value=0, units_func=identity)
+nanfirst = Aggregation(
+    "nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan, units_func=identity
+)
+nanlast = Aggregation(
+    "nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan, units_func=identity
+)
 
 all_ = Aggregation(
     "all",
@@ -483,6 +529,7 @@ def _initialize_aggregation(
     dtype,
     array_dtype,
     fill_value,
+    array_units,
     min_count: int | None,
     finalize_kwargs,
 ) -> Aggregation:
@@ -547,4 +594,8 @@ def _initialize_aggregation(
         agg.dtype["intermediate"] += (np.intp,)
         agg.dtype["numpy"] += (np.intp,)
 
+    if array_units is not None and agg.units_func is not None:
+        import pint
+
+        agg.units = agg.units_func(pint.Quantity([1], units=array_units))
     return agg

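The net effect of the aggregations.py changes: each Aggregation now carries a units_func mapping input units to output units (identity for unit-preserving reductions such as sum, mean, min and max; square for var/nanvar; raise_units_error for prod/nanprod, whose output units depend on group sizes), and _initialize_aggregation evaluates it once on a one-element Quantity. A small illustration, assuming pint is installed (this snippet is not part of the diff):

# Illustration only: how the units_func hook resolves output units.
import pint

from flox.aggregations import identity, square  # added by this commit

template = pint.Quantity([1], units="m")

identity(template).units  # <Unit('meter')>       -- sum, mean, min, max, ...
square(template).units    # <Unit('meter ** 2')>  -- var, nanvar
# raise_units_error(template) raises ValueError, so prod/nanprod refuse units.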
flox/core.py

+9 −1
@@ -24,6 +24,7 @@
     generic_aggregate,
 )
 from .cache import memoize
+from .pint_compat import _reattach_units, _strip_units
 from .xrutils import is_duck_array, is_duck_dask_array, isnull
 
 if TYPE_CHECKING:
@@ -1702,6 +1703,8 @@ def groupby_reduce(
     by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
     any_by_dask = any(by_is_dask)
 
+    array, *bys, units = _strip_units(array, *bys)
+
     if method in ["split-reduce", "cohorts"] and any_by_dask:
         raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
 
@@ -1803,7 +1806,9 @@ def groupby_reduce(
         fill_value = np.nan
 
     kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
-    agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
+    agg = _initialize_aggregation(
+        func, dtype, array.dtype, fill_value, units[0], min_count, finalize_kwargs
+    )
 
     if not has_dask:
         results = _reduce_blockwise(
@@ -1862,4 +1867,7 @@ def groupby_reduce(
 
     if _is_minmax_reduction(func) and is_bool_array:
         result = result.astype(bool)
+
+    units[0] = agg.units
+    result, *groups = _reattach_units(result, *groups, units=units)
     return (result, *groups)

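With the core.py changes, groupby_reduce strips units on entry, passes the input units into _initialize_aggregation, and reattaches agg.units to the result before returning. A usage sketch that mirrors the new tests added below (exact output types depend on the pint_compat helpers):

# Sketch of the expected behaviour after this change (mirrors test_pint below).
import numpy as np
import pint

from flox.core import groupby_reduce

q = pint.Quantity(np.array([1.0, 2.0, 3.0]), units="m")
result, groups = groupby_reduce(q, np.array([0, 0, 1]), func="var")
# "var" uses units_func=square, so ``result`` comes back as a Quantity in m**2.
# Reductions like "count" or "all" define no units_func and return plain arrays.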
tests/__init__.py

+16
@@ -24,6 +24,13 @@
 except ImportError:
     xr_types = ()  # type: ignore
 
+try:
+    import pint
+
+    pint_types = pint.Quantity
+except ImportError:
+    pint_types = ()  # type: ignore
+
 
 def _importorskip(modname, minversion=None):
     try:
@@ -46,6 +53,7 @@ def LooseVersion(vstring):
 
 
 has_dask, requires_dask = _importorskip("dask")
+has_pint, requires_pint = _importorskip("pint")
 has_xarray, requires_xarray = _importorskip("xarray")
 
 
@@ -95,6 +103,14 @@ def assert_equal(a, b, tolerance=None):
         xr.testing.assert_identical(a, b)
         return
 
+    if has_pint and isinstance(a, pint_types) or isinstance(b, pint_types):
+        assert isinstance(a, pint_types)
+        assert isinstance(b, pint_types)
+        assert a.units == b.units
+
+        a = a.magnitude
+        b = b.magnitude
+
     if tolerance is None and (
         np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64)
     ):

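The extended assert_equal checks the units first and then falls through to the existing magnitude comparison. Roughly how the new tests rely on it (a sketch; the import path assumes the repository root is on sys.path):

# Sketch of the extended assert_equal with pint inputs.
import numpy as np
import pint

from tests import assert_equal  # assumes the flox repository root is importable

a = pint.Quantity(np.array([1.0, 2.0]), units="m")
b = pint.Quantity(np.array([1.0, 2.0]), units="m")
assert_equal(a, b)  # units are compared first, then the magnitudes as before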
tests/test_core.py

+37
@@ -31,6 +31,7 @@
     has_dask,
     raise_if_dask_computes,
     requires_dask,
+    requires_pint,
 )
 
 labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
@@ -1321,3 +1322,39 @@ def test_negative_index_factorize_race_condition():
         for f in func
     ]
     [dask.compute(out, scheduler="threads") for _ in range(5)]
+
+
+@requires_pint
+@pytest.mark.parametrize("func", ["all", "count", "sum", "var"])
+@pytest.mark.parametrize("chunk", [True, False])
+def test_pint(chunk, func):
+    import pint
+
+    if chunk:
+        d = dask.array.array([1, 2, 3])
+    else:
+        d = np.array([1, 2, 3])
+    q = pint.Quantity(d, units="m")
+
+    actual, _ = groupby_reduce(q, [0, 0, 1], func=func)
+    expected, _ = groupby_reduce(q.magnitude, [0, 0, 1], func=func)
+
+    units = None if func in ["count", "all"] else getattr(np, func)(q).units
+    if units is not None:
+        expected = pint.Quantity(expected, units=units)
+    assert_equal(expected, actual)
+
+
+@requires_pint
+@pytest.mark.parametrize("chunk", [True, False])
+def test_pint_prod_error(chunk):
+    import pint
+
+    if chunk:
+        d = dask.array.array([1, 2, 3])
+    else:
+        d = np.array([1, 2, 3])
+    q = pint.Quantity(d, units="m")
+
+    with pytest.raises(ValueError):
+        groupby_reduce(q, [0, 0, 1], func="prod")
