Skip to content

Commit aa358a5

Browse files
authored
Handle min_count=0 (#238)
* Hande min_count=0 * fix
1 parent 32f1ac3 commit aa358a5

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

flox/aggregations.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -557,12 +557,15 @@ def _initialize_aggregation(
557557
assert isinstance(finalize_kwargs, dict)
558558
agg.finalize_kwargs = finalize_kwargs
559559

560+
if min_count is None:
561+
min_count = 0
562+
560563
# This is needed for the dask pathway.
561564
# Because we use intermediate fill_value since a group could be
562565
# absent in one block, but present in another block
563566
# We set it for numpy to get nansum, nanprod tests to pass
564567
# where the identity element is 0, 1
565-
if min_count is not None:
568+
if min_count > 0:
566569
agg.min_count = min_count
567570
agg.chunk += ("nanlen",)
568571
agg.numpy += ("nanlen",)
@@ -571,5 +574,7 @@ def _initialize_aggregation(
571574
agg.fill_value["numpy"] += (0,)
572575
agg.dtype["intermediate"] += (np.intp,)
573576
agg.dtype["numpy"] += (np.intp,)
577+
else:
578+
agg.min_count = 0
574579

575580
return agg

flox/core.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def _finalize_results(
849849
"""
850850
squeezed = _squeeze_results(results, axis)
851851

852-
if agg.min_count is not None:
852+
if agg.min_count > 0:
853853
counts = squeezed["intermediates"][-1]
854854
squeezed["intermediates"] = squeezed["intermediates"][:-1]
855855

@@ -860,7 +860,7 @@ def _finalize_results(
860860
else:
861861
finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
862862

863-
if agg.min_count is not None:
863+
if agg.min_count > 0:
864864
count_mask = counts < agg.min_count
865865
if count_mask.any():
866866
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
@@ -1917,7 +1917,12 @@ def groupby_reduce(
19171917
min_count = 1
19181918

19191919
# TODO: set in xarray?
1920-
if min_count is not None and func in ["nansum", "nanprod"] and fill_value is None:
1920+
if (
1921+
min_count is not None
1922+
and min_count > 0
1923+
and func in ["nansum", "nanprod"]
1924+
and fill_value is None
1925+
):
19211926
# nansum, nanprod have fill_value=0, 1
19221927
# overwrite than when min_count is set
19231928
fill_value = np.nan

0 commit comments

Comments
 (0)