3
3
import copy
4
4
import warnings
5
5
from functools import partial
6
+ from typing import Callable
6
7
7
8
import numpy as np
8
9
import numpy_groupies as npg
@@ -114,6 +115,7 @@ def __init__(
114
115
dtypes = None ,
115
116
final_dtype = None ,
116
117
reduction_type = "reduce" ,
118
+ units_func : Callable | None = None ,
117
119
):
118
120
"""
119
121
Blueprint for computing grouped aggregations.
@@ -156,6 +158,8 @@ def __init__(
156
158
per reduction in ``chunk`` as a tuple.
157
159
final_dtype : DType, optional
158
160
DType for output. By default, uses dtype of array being reduced.
161
+ units_func : pint.Unit
162
+ units for the output
159
163
"""
160
164
self .name = name
161
165
# preprocess before blockwise
@@ -187,6 +191,8 @@ def __init__(
187
191
# The following are set by _initialize_aggregation
188
192
self .finalize_kwargs = {}
189
193
self .min_count = None
194
+ self .units_func = units_func
195
+ self .units = None
190
196
191
197
def _normalize_dtype_fill_value (self , value , name ):
192
198
value = _atleast_1d (value )
@@ -235,17 +241,44 @@ def __repr__(self):
235
241
final_dtype = np .intp ,
236
242
)
237
243
244
+
245
+ def identity (x ):
246
+ return x
247
+
248
+
249
+ def square (x ):
250
+ return x ** 2
251
+
252
+
253
+ def raise_units_error (x ):
254
+ raise ValueError (
255
+ "Units cannot supported for prod in general. "
256
+ "We can only attach units when there are "
257
+ "equal number of members in each group. "
258
+ "Please strip units and then reattach units "
259
+ "to the output manually."
260
+ )
261
+
262
+
238
263
# note that the fill values are the result of np.func([np.nan, np.nan])
239
264
# final_fill_value is used for groups that don't exist. This is usually np.nan
240
- sum_ = Aggregation ("sum" , chunk = "sum" , combine = "sum" , fill_value = 0 )
241
- nansum = Aggregation ("nansum" , chunk = "nansum" , combine = "sum" , fill_value = 0 )
242
- prod = Aggregation ("prod" , chunk = "prod" , combine = "prod" , fill_value = 1 , final_fill_value = 1 )
265
+ sum_ = Aggregation ("sum" , chunk = "sum" , combine = "sum" , fill_value = 0 , units_func = identity )
266
+ nansum = Aggregation ("nansum" , chunk = "nansum" , combine = "sum" , fill_value = 0 , units_func = identity )
267
+ prod = Aggregation (
268
+ "prod" ,
269
+ chunk = "prod" ,
270
+ combine = "prod" ,
271
+ fill_value = 1 ,
272
+ final_fill_value = 1 ,
273
+ units_func = raise_units_error ,
274
+ )
243
275
nanprod = Aggregation (
244
276
"nanprod" ,
245
277
chunk = "nanprod" ,
246
278
combine = "prod" ,
247
279
fill_value = 1 ,
248
280
final_fill_value = dtypes .NA ,
281
+ units_func = raise_units_error ,
249
282
)
250
283
251
284
@@ -262,6 +295,7 @@ def _mean_finalize(sum_, count):
262
295
fill_value = (0 , 0 ),
263
296
dtypes = (None , np .intp ),
264
297
final_dtype = np .floating ,
298
+ units_func = identity ,
265
299
)
266
300
nanmean = Aggregation (
267
301
"nanmean" ,
@@ -271,6 +305,7 @@ def _mean_finalize(sum_, count):
271
305
fill_value = (0 , 0 ),
272
306
dtypes = (None , np .intp ),
273
307
final_dtype = np .floating ,
308
+ units_func = identity ,
274
309
)
275
310
276
311
@@ -296,6 +331,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
296
331
final_fill_value = np .nan ,
297
332
dtypes = (None , None , np .intp ),
298
333
final_dtype = np .floating ,
334
+ units_func = square ,
299
335
)
300
336
nanvar = Aggregation (
301
337
"nanvar" ,
@@ -306,6 +342,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
306
342
final_fill_value = np .nan ,
307
343
dtypes = (None , None , np .intp ),
308
344
final_dtype = np .floating ,
345
+ units_func = square ,
309
346
)
310
347
std = Aggregation (
311
348
"std" ,
@@ -316,6 +353,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
316
353
final_fill_value = np .nan ,
317
354
dtypes = (None , None , np .intp ),
318
355
final_dtype = np .floating ,
356
+ units_func = identity ,
319
357
)
320
358
nanstd = Aggregation (
321
359
"nanstd" ,
@@ -329,10 +367,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
329
367
)
330
368
331
369
332
- min_ = Aggregation ("min" , chunk = "min" , combine = "min" , fill_value = dtypes .INF )
333
- nanmin = Aggregation ("nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = np .nan )
334
- max_ = Aggregation ("max" , chunk = "max" , combine = "max" , fill_value = dtypes .NINF )
335
- nanmax = Aggregation ("nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = np .nan )
370
+ min_ = Aggregation ("min" , chunk = "min" , combine = "min" , fill_value = dtypes .INF , units_func = identity )
371
+ nanmin = Aggregation (
372
+ "nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = np .nan , units_func = identity
373
+ )
374
+ max_ = Aggregation ("max" , chunk = "max" , combine = "max" , fill_value = dtypes .NINF , units_func = identity )
375
+ nanmax = Aggregation (
376
+ "nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = np .nan , units_func = identity
377
+ )
336
378
337
379
338
380
def argreduce_preprocess (array , axis ):
@@ -420,10 +462,14 @@ def _pick_second(*x):
420
462
final_dtype = np .intp ,
421
463
)
422
464
423
- first = Aggregation ("first" , chunk = None , combine = None , fill_value = 0 )
424
- last = Aggregation ("last" , chunk = None , combine = None , fill_value = 0 )
425
- nanfirst = Aggregation ("nanfirst" , chunk = "nanfirst" , combine = "nanfirst" , fill_value = np .nan )
426
- nanlast = Aggregation ("nanlast" , chunk = "nanlast" , combine = "nanlast" , fill_value = np .nan )
465
+ first = Aggregation ("first" , chunk = None , combine = None , fill_value = 0 , units_func = identity )
466
+ last = Aggregation ("last" , chunk = None , combine = None , fill_value = 0 , units_func = identity )
467
+ nanfirst = Aggregation (
468
+ "nanfirst" , chunk = "nanfirst" , combine = "nanfirst" , fill_value = np .nan , units_func = identity
469
+ )
470
+ nanlast = Aggregation (
471
+ "nanlast" , chunk = "nanlast" , combine = "nanlast" , fill_value = np .nan , units_func = identity
472
+ )
427
473
428
474
all_ = Aggregation (
429
475
"all" ,
@@ -483,6 +529,7 @@ def _initialize_aggregation(
483
529
dtype ,
484
530
array_dtype ,
485
531
fill_value ,
532
+ array_units ,
486
533
min_count : int | None ,
487
534
finalize_kwargs ,
488
535
) -> Aggregation :
@@ -547,4 +594,8 @@ def _initialize_aggregation(
547
594
agg .dtype ["intermediate" ] += (np .intp ,)
548
595
agg .dtype ["numpy" ] += (np .intp ,)
549
596
597
+ if array_units is not None and agg .units_func is not None :
598
+ import pint
599
+
600
+ agg .units = agg .units_func (pint .Quantity ([1 ], units = array_units ))
550
601
return agg
0 commit comments