Skip to content

Commit e02c2f3

Browse files
authored
Merge branch 'main' into optimize-engine-flox
2 parents 29a0e2a + aa358a5 commit e02c2f3

13 files changed

+73
-38
lines changed

.github/workflows/ci-additional.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
run: |
7474
python -m pytest --doctest-modules flox --ignore flox/tests --cov=./ --cov-report=xml
7575
- name: Upload code coverage to Codecov
76-
uses: codecov/[email protected].1
76+
uses: codecov/[email protected].3
7777
with:
7878
file: ./coverage.xml
7979
flags: unittests
@@ -126,7 +126,7 @@ jobs:
126126
python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report
127127
128128
- name: Upload mypy coverage to Codecov
129-
uses: codecov/[email protected].1
129+
uses: codecov/[email protected].3
130130
with:
131131
file: mypy_report/cobertura.xml
132132
flags: mypy

.github/workflows/ci.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
run: |
4949
pytest -n auto --cov=./ --cov-report=xml
5050
- name: Upload code coverage to Codecov
51-
uses: codecov/[email protected].1
51+
uses: codecov/[email protected].3
5252
with:
5353
file: ./coverage.xml
5454
flags: unittests
@@ -91,7 +91,7 @@ jobs:
9191
run: |
9292
python -m pytest -n auto --cov=./ --cov-report=xml
9393
- name: Upload code coverage to Codecov
94-
uses: codecov/[email protected].1
94+
uses: codecov/[email protected].3
9595
with:
9696
file: ./coverage.xml
9797
flags: unittests

.pre-commit-config.yaml

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ ci:
44
repos:
55
- repo: https://github.com/charliermarsh/ruff-pre-commit
66
# Ruff version.
7-
rev: 'v0.0.246'
7+
rev: 'v0.0.260'
88
hooks:
99
- id: ruff
1010
args: ["--fix"]
@@ -18,7 +18,7 @@ repos:
1818
- id: check-docstring-first
1919

2020
- repo: https://github.com/psf/black
21-
rev: 23.1.0
21+
rev: 23.3.0
2222
hooks:
2323
- id: black
2424

@@ -31,7 +31,7 @@ repos:
3131
- mdformat-myst
3232

3333
- repo: https://github.com/nbQA-dev/nbQA
34-
rev: 1.6.1
34+
rev: 1.7.0
3535
hooks:
3636
- id: nbqa-black
3737
- id: nbqa-ruff
@@ -44,13 +44,13 @@ repos:
4444
args: [--extra-keys=metadata.kernelspec metadata.language_info.version]
4545

4646
- repo: https://github.com/codespell-project/codespell
47-
rev: v2.2.2
47+
rev: v2.2.4
4848
hooks:
4949
- id: codespell
5050
additional_dependencies:
5151
- tomli
5252

5353
- repo: https://github.com/abravalheri/validate-pyproject
54-
rev: v0.12.1
54+
rev: v0.12.2
5555
hooks:
5656
- id: validate-pyproject

ci/environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies:
1414
- pip
1515
- pytest
1616
- pytest-cov
17+
- pytest-pretty
1718
- pytest-xdist
1819
- xarray
1920
- pre-commit

ci/minimal-requirements.yml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ dependencies:
77
- pip
88
- pytest
99
- pytest-cov
10+
- pytest-pretty
1011
- pytest-xdist
1112
- numpy==1.20
1213
- numpy_groupies==0.9.19

ci/no-dask.yml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies:
99
- pip
1010
- pytest
1111
- pytest-cov
12+
- pytest-pretty
1213
- pytest-xdist
1314
- xarray
1415
- numpydoc

ci/no-xarray.yml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies:
99
- pip
1010
- pytest
1111
- pytest-cov
12+
- pytest-pretty
1213
- pytest-xdist
1314
- dask-core
1415
- numpydoc

ci/upstream-dev-env.yml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- numba
1111
- pytest
1212
- pytest-cov
13+
- pytest-pretty
1314
- pytest-xdist
1415
- pip
1516
- pip:

flox/aggregations.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -557,12 +557,15 @@ def _initialize_aggregation(
557557
assert isinstance(finalize_kwargs, dict)
558558
agg.finalize_kwargs = finalize_kwargs
559559

560+
if min_count is None:
561+
min_count = 0
562+
560563
# This is needed for the dask pathway.
561564
# Because we use intermediate fill_value since a group could be
562565
# absent in one block, but present in another block
563566
# We set it for numpy to get nansum, nanprod tests to pass
564567
# where the identity element is 0, 1
565-
if min_count is not None:
568+
if min_count > 0:
566569
agg.min_count = min_count
567570
agg.chunk += ("nanlen",)
568571
agg.numpy += ("nanlen",)
@@ -571,5 +574,7 @@ def _initialize_aggregation(
571574
agg.fill_value["numpy"] += (0,)
572575
agg.dtype["intermediate"] += (np.intp,)
573576
agg.dtype["numpy"] += (np.intp,)
577+
else:
578+
agg.min_count = 0
574579

575580
return agg

flox/core.py

+39-15
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def rechunk_for_cohorts(
354354
def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray:
355355
"""
356356
Rechunks array so that group boundaries line up with chunk boundaries, allowing
357-
embarassingly parallel group reductions.
357+
embarrassingly parallel group reductions.
358358
359359
This only works when the groups are sequential
360360
(e.g. labels = ``[0,0,0,1,1,1,1,2,2]``).
@@ -849,7 +849,7 @@ def _finalize_results(
849849
"""
850850
squeezed = _squeeze_results(results, axis)
851851

852-
if agg.min_count is not None:
852+
if agg.min_count > 0:
853853
counts = squeezed["intermediates"][-1]
854854
squeezed["intermediates"] = squeezed["intermediates"][:-1]
855855

@@ -860,7 +860,7 @@ def _finalize_results(
860860
else:
861861
finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
862862

863-
if agg.min_count is not None:
863+
if agg.min_count > 0:
864864
count_mask = counts < agg.min_count
865865
if count_mask.any():
866866
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
@@ -1598,7 +1598,11 @@ def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:
15981598

15991599

16001600
def _factorize_multiple(
1601-
by: T_Bys, expected_groups: T_ExpectIndexTuple, any_by_dask: bool, reindex: bool
1601+
by: T_Bys,
1602+
expected_groups: T_ExpectIndexTuple,
1603+
any_by_dask: bool,
1604+
reindex: bool,
1605+
sort: bool = True,
16021606
) -> tuple[tuple[np.ndarray], tuple[np.ndarray, ...], tuple[int, ...]]:
16031607
if any_by_dask:
16041608
import dask.array
@@ -1617,6 +1621,7 @@ def _factorize_multiple(
16171621
expected_groups=expected_groups,
16181622
fastpath=True,
16191623
reindex=reindex,
1624+
sort=sort,
16201625
)
16211626

16221627
fg, gs = [], []
@@ -1643,6 +1648,7 @@ def _factorize_multiple(
16431648
expected_groups=expected_groups,
16441649
fastpath=True,
16451650
reindex=reindex,
1651+
sort=sort,
16461652
)
16471653

16481654
return (group_idx,), found_groups, grp_shape
@@ -1653,10 +1659,16 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
16531659
return (None,) * nby
16541660

16551661
if nby == 1 and not isinstance(expected_groups, tuple):
1656-
if isinstance(expected_groups, pd.Index):
1662+
if isinstance(expected_groups, (pd.Index, np.ndarray)):
16571663
return (expected_groups,)
16581664
else:
1659-
return (np.asarray(expected_groups),)
1665+
array = np.asarray(expected_groups)
1666+
if np.issubdtype(array.dtype, np.integer):
1667+
# preserve default dtypes
1668+
# on pandas 1.5/2, on windows
1669+
# when a list is passed
1670+
array = array.astype(np.int64)
1671+
return (array,)
16601672

16611673
if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list
16621674
raise ValueError(
@@ -1833,21 +1845,28 @@ def groupby_reduce(
18331845
# (pd.IntervalIndex or not)
18341846
expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)
18351847

1836-
is_binning = any([isinstance(e, pd.IntervalIndex) for e in expected_groups])
1837-
1838-
# TODO: could restrict this to dask-only
1839-
factorize_early = (nby > 1) or (
1840-
is_binning and method == "cohorts" and is_duck_dask_array(array)
1848+
# Don't factorize "early only when
1849+
# grouping by dask arrays, and not having expected_groups
1850+
factorize_early = not (
1851+
# can't do it if we are grouping by dask array but don't have expected_groups
1852+
any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
18411853
)
18421854
if factorize_early:
18431855
bys, final_groups, grp_shape = _factorize_multiple(
1844-
bys, expected_groups, any_by_dask=any_by_dask, reindex=reindex
1856+
bys,
1857+
expected_groups,
1858+
any_by_dask=any_by_dask,
1859+
# This is the only way it makes sense I think.
1860+
# reindex controls what's actually allocated in chunk_reduce
1861+
# At this point, we care about an accurate conversion to codes.
1862+
reindex=True,
1863+
sort=sort,
18451864
)
18461865
expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)
18471866

18481867
assert len(bys) == 1
1849-
by_ = bys[0]
1850-
expected_groups = expected_groups[0]
1868+
(by_,) = bys
1869+
(expected_groups,) = expected_groups
18511870

18521871
if axis is None:
18531872
axis_ = tuple(array.ndim + np.arange(-by_.ndim, 0))
@@ -1898,7 +1917,12 @@ def groupby_reduce(
18981917
min_count = 1
18991918

19001919
# TODO: set in xarray?
1901-
if min_count is not None and func in ["nansum", "nanprod"] and fill_value is None:
1920+
if (
1921+
min_count is not None
1922+
and min_count > 0
1923+
and func in ["nansum", "nanprod"]
1924+
and fill_value is None
1925+
):
19021926
# nansum, nanprod have fill_value=0, 1
19031927
# overwrite than when min_count is set
19041928
fill_value = np.nan

flox/xarray.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def rechunk_for_cohorts(
534534
def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_DataArray):
535535
"""
536536
Rechunks array so that group boundaries line up with chunk boundaries, allowing
537-
embarassingly parallel group reductions.
537+
embarrassingly parallel group reductions.
538538
539539
This only works when the groups are sequential
540540
(e.g. labels = ``[0,0,0,1,1,1,1,2,2]``).

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33

4-
@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
4+
@pytest.fixture(scope="module", params=["flox"])
55
def engine(request):
66
if request.param == "numba":
77
try:

tests/test_core.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def test_alignment_error():
8888

8989
@pytest.mark.parametrize("dtype", (float, int))
9090
@pytest.mark.parametrize("chunk", [False, True])
91-
@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
91+
# TODO: make this intp when python 3.8 is dropped
92+
@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2], dtype=np.int64)])
9293
@pytest.mark.parametrize(
9394
"func, array, by, expected",
9495
[
@@ -148,7 +149,12 @@ def test_groupby_reduce(
148149
)
149150
# we use pd.Index(expected_groups).to_numpy() which is always int64
150151
# for the values in this tests
151-
g_dtype = by.dtype if expected_groups is None else np.int64
152+
if expected_groups is None:
153+
g_dtype = by.dtype
154+
elif isinstance(expected_groups, np.ndarray):
155+
g_dtype = expected_groups.dtype
156+
else:
157+
g_dtype = np.int64
152158

153159
assert_equal(groups, np.array([0, 1, 2], g_dtype))
154160
assert_equal(expected_result, result)
@@ -653,7 +659,7 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
653659
array = [1, 1, 1, 1, 1, 1]
654660
labels = [0.2, 1.5, 1.9, 2, 3, 20]
655661

656-
if method in ["split-reduce", "cohorts"] and chunk_labels:
662+
if method == "cohorts" and chunk_labels:
657663
pytest.xfail()
658664

659665
if chunks:
@@ -784,10 +790,8 @@ def test_dtype_preservation(dtype, func, engine):
784790

785791

786792
@requires_dask
787-
@pytest.mark.parametrize("dtype", [np.int32, np.int64])
788-
@pytest.mark.parametrize(
789-
"labels_dtype", [pytest.param(np.int32, marks=pytest.mark.xfail), np.int64]
790-
)
793+
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32, np.int64])
794+
@pytest.mark.parametrize("labels_dtype", [np.float32, np.float64, np.int32, np.int64])
791795
@pytest.mark.parametrize("method", ["map-reduce", "cohorts"])
792796
def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
793797
repeats = np.array([4, 4, 12, 2, 3, 4], dtype=np.int32)
@@ -836,10 +840,7 @@ def test_cohorts_nd_by(func, method, axis, engine):
836840
assert_equal(actual, expected)
837841

838842
actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
839-
if method == "map-reduce":
840-
assert_equal(groups, np.array([1, 30, 2, 31, 3, 4, 40], dtype=np.int64))
841-
else:
842-
assert_equal(groups, np.array([1, 30, 2, 31, 3, 40, 4], dtype=np.int64))
843+
assert_equal(groups, np.array([1, 30, 2, 31, 3, 4, 40], dtype=np.int64))
843844
reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
844845
assert_equal(reindexed, expected)
845846

0 commit comments

Comments
 (0)