@@ -115,60 +115,6 @@ def generic_aggregate(
115
115
return result
116
116
117
117
118
- def _normalize_dtype (dtype : DTypeLike , array_dtype : np .dtype , fill_value = None ) -> np .dtype :
119
- if dtype is None :
120
- dtype = array_dtype
121
- if dtype is np .floating :
122
- # mean, std, var always result in floating
123
- # but we preserve the array's dtype if it is floating
124
- if array_dtype .kind in "fcmM" :
125
- dtype = array_dtype
126
- else :
127
- dtype = np .dtype ("float64" )
128
- elif not isinstance (dtype , np .dtype ):
129
- dtype = np .dtype (dtype )
130
- if fill_value not in [None , dtypes .INF , dtypes .NINF , dtypes .NA ]:
131
- dtype = np .result_type (dtype , fill_value )
132
- return dtype
133
-
134
-
135
- def _maybe_promote_int (dtype ) -> np .dtype :
136
- # https://numpy.org/doc/stable/reference/generated/numpy.prod.html
137
- # The dtype of a is used by default unless a has an integer dtype of less precision
138
- # than the default platform integer.
139
- if not isinstance (dtype , np .dtype ):
140
- dtype = np .dtype (dtype )
141
- if dtype .kind == "i" :
142
- dtype = np .result_type (dtype , np .intp )
143
- elif dtype .kind == "u" :
144
- dtype = np .result_type (dtype , np .uintp )
145
- return dtype
146
-
147
-
148
- def _get_fill_value (dtype , fill_value ):
149
- """Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
150
- if fill_value in [None , dtypes .NA ] and dtype .kind in "US" :
151
- return ""
152
- if fill_value == dtypes .INF or fill_value is None :
153
- return dtypes .get_pos_infinity (dtype , max_for_int = True )
154
- if fill_value == dtypes .NINF :
155
- return dtypes .get_neg_infinity (dtype , min_for_int = True )
156
- if fill_value == dtypes .NA :
157
- if np .issubdtype (dtype , np .floating ) or np .issubdtype (dtype , np .complexfloating ):
158
- return np .nan
159
- # This is madness, but npg checks that fill_value is compatible
160
- # with array dtype even if the fill_value is never used.
161
- elif (
162
- np .issubdtype (dtype , np .integer )
163
- or np .issubdtype (dtype , np .timedelta64 )
164
- or np .issubdtype (dtype , np .datetime64 )
165
- ):
166
- return dtypes .get_neg_infinity (dtype , min_for_int = True )
167
- else :
168
- return None
169
- return fill_value
170
-
171
-
172
118
def _atleast_1d (inp , min_length : int = 1 ):
173
119
if xrutils .is_scalar (inp ):
174
120
inp = (inp ,) * min_length
@@ -210,6 +156,7 @@ def __init__(
210
156
final_dtype : DTypeLike | None = None ,
211
157
reduction_type : Literal ["reduce" , "argreduce" ] = "reduce" ,
212
158
new_dims_func : Callable | None = None ,
159
+ preserves_dtype : bool = False ,
213
160
):
214
161
"""
215
162
Blueprint for computing grouped aggregations.
@@ -256,6 +203,8 @@ def __init__(
256
203
Function that receives finalize_kwargs and returns a tupleof sizes of any new dimensions
257
204
added by the reduction. For e.g. quantile for q=(0.5, 0.85) adds a new dimension of size 2,
258
205
so returns (2,)
206
+ preserves_dtype: bool, optional
207
+ Whether a function preserves the dtype on return, e.g. min, max, first, last, mode.
259
208
"""
260
209
self .name = name
261
210
# preprocess before blockwise
@@ -292,6 +241,7 @@ def __init__(
292
241
self .new_dims_func : Callable = (
293
242
returns_empty_tuple if new_dims_func is None else new_dims_func
294
243
)
244
+ self .preserves_dtype = preserves_dtype
295
245
296
246
@cached_property
297
247
def new_dims (self ) -> tuple [Dim ]:
@@ -434,10 +384,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
434
384
)
435
385
436
386
437
- min_ = Aggregation ("min" , chunk = "min" , combine = "min" , fill_value = dtypes .INF )
438
- nanmin = Aggregation ("nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = np .nan )
439
- max_ = Aggregation ("max" , chunk = "max" , combine = "max" , fill_value = dtypes .NINF )
440
- nanmax = Aggregation ("nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = np .nan )
387
+ min_ = Aggregation ("min" , chunk = "min" , combine = "min" , fill_value = dtypes .INF , preserves_dtype = True )
388
+ nanmin = Aggregation (
389
+ "nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = dtypes .NA , preserves_dtype = True
390
+ )
391
+ max_ = Aggregation ("max" , chunk = "max" , combine = "max" , fill_value = dtypes .NINF , preserves_dtype = True )
392
+ nanmax = Aggregation (
393
+ "nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = dtypes .NA , preserves_dtype = True
394
+ )
441
395
442
396
443
397
def argreduce_preprocess (array , axis ):
@@ -525,10 +479,14 @@ def _pick_second(*x):
525
479
final_dtype = np .intp ,
526
480
)
527
481
528
- first = Aggregation ("first" , chunk = None , combine = None , fill_value = None )
529
- last = Aggregation ("last" , chunk = None , combine = None , fill_value = None )
530
- nanfirst = Aggregation ("nanfirst" , chunk = "nanfirst" , combine = "nanfirst" , fill_value = dtypes .NA )
531
- nanlast = Aggregation ("nanlast" , chunk = "nanlast" , combine = "nanlast" , fill_value = dtypes .NA )
482
+ first = Aggregation ("first" , chunk = None , combine = None , fill_value = None , preserves_dtype = True )
483
+ last = Aggregation ("last" , chunk = None , combine = None , fill_value = None , preserves_dtype = True )
484
+ nanfirst = Aggregation (
485
+ "nanfirst" , chunk = "nanfirst" , combine = "nanfirst" , fill_value = dtypes .NA , preserves_dtype = True
486
+ )
487
+ nanlast = Aggregation (
488
+ "nanlast" , chunk = "nanlast" , combine = "nanlast" , fill_value = dtypes .NA , preserves_dtype = True
489
+ )
532
490
533
491
all_ = Aggregation (
534
492
"all" ,
@@ -579,8 +537,12 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
579
537
final_dtype = np .floating ,
580
538
new_dims_func = quantile_new_dims_func ,
581
539
)
582
- mode = Aggregation (name = "mode" , fill_value = dtypes .NA , chunk = None , combine = None )
583
- nanmode = Aggregation (name = "nanmode" , fill_value = dtypes .NA , chunk = None , combine = None )
540
+ mode = Aggregation (
541
+ name = "mode" , fill_value = dtypes .NA , chunk = None , combine = None , preserves_dtype = True
542
+ )
543
+ nanmode = Aggregation (
544
+ name = "nanmode" , fill_value = dtypes .NA , chunk = None , combine = None , preserves_dtype = True
545
+ )
584
546
585
547
586
548
@dataclass
@@ -634,7 +596,7 @@ def last(self) -> AlignedArrays:
634
596
# TODO: automate?
635
597
engine = "flox" ,
636
598
dtype = self .array .dtype ,
637
- fill_value = _get_fill_value (self .array .dtype , dtypes .NA ),
599
+ fill_value = dtypes . _get_fill_value (self .array .dtype , dtypes .NA ),
638
600
expected_groups = None ,
639
601
)
640
602
return AlignedArrays (array = reduced ["intermediates" ][0 ], group_idx = reduced ["groups" ])
@@ -729,6 +691,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
729
691
binary_op = None ,
730
692
reduction = "nanlast" ,
731
693
scan = "ffill" ,
694
+ # Important: this must be NaN; otherwise ffill does not work.
732
695
identity = np .nan ,
733
696
mode = "concat_then_scan" ,
734
697
)
@@ -737,6 +700,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
737
700
binary_op = None ,
738
701
reduction = "nanlast" ,
739
702
scan = "ffill" ,
703
+ # Important: this must be NaN; otherwise bfill does not work.
740
704
identity = np .nan ,
741
705
mode = "concat_then_scan" ,
742
706
preprocess = reverse ,
@@ -815,17 +779,18 @@ def _initialize_aggregation(
815
779
dtype_ : np .dtype | None = (
816
780
np .dtype (dtype ) if dtype is not None and not isinstance (dtype , np .dtype ) else dtype
817
781
)
818
-
819
- final_dtype = _normalize_dtype (dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value )
820
- if agg .name not in ["first" , "last" , "nanfirst" , "nanlast" , "min" , "max" , "nanmin" , "nanmax" ]:
821
- final_dtype = _maybe_promote_int (final_dtype )
782
+ final_dtype = dtypes ._normalize_dtype (
783
+ dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value
784
+ )
785
+ if not agg .preserves_dtype :
786
+ final_dtype = dtypes ._maybe_promote_int (final_dtype )
822
787
agg .dtype = {
823
788
"user" : dtype , # Save to automatically choose an engine
824
789
"final" : final_dtype ,
825
790
"numpy" : (final_dtype ,),
826
791
"intermediate" : tuple (
827
792
(
828
- _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
793
+ dtypes . _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
829
794
if int_dtype is None
830
795
else np .dtype (int_dtype )
831
796
)
@@ -838,10 +803,10 @@ def _initialize_aggregation(
838
803
# Replace sentinel fill values according to dtype
839
804
agg .fill_value ["user" ] = fill_value
840
805
agg .fill_value ["intermediate" ] = tuple (
841
- _get_fill_value (dt , fv )
806
+ dtypes . _get_fill_value (dt , fv )
842
807
for dt , fv in zip (agg .dtype ["intermediate" ], agg .fill_value ["intermediate" ])
843
808
)
844
- agg .fill_value [func ] = _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
809
+ agg .fill_value [func ] = dtypes . _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
845
810
846
811
fv = fill_value if fill_value is not None else agg .fill_value [agg .name ]
847
812
if _is_arg_reduction (agg ):
0 commit comments