Skip to content

Commit 32f1ac3

Browse files
committed
Try and fix dtypes on 3.8,3.10 windows
When passed a list always upcast to int64. This makes it work n pandas 1.5 which is what we get on py3.8 because of xarray's pinning
1 parent d8eb4c8 commit 32f1ac3

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

Diff for: flox/core.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1659,10 +1659,16 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
16591659
return (None,) * nby
16601660

16611661
if nby == 1 and not isinstance(expected_groups, tuple):
1662-
if isinstance(expected_groups, pd.Index):
1662+
if isinstance(expected_groups, (pd.Index, np.ndarray)):
16631663
return (expected_groups,)
16641664
else:
1665-
return (np.asarray(expected_groups),)
1665+
array = np.asarray(expected_groups)
1666+
if np.issubdtype(array.dtype, np.integer):
1667+
# preserve default dtypes
1668+
# on pandas 1.5/2, on windows
1669+
# when a list is passed
1670+
array = array.astype(np.int64)
1671+
return (array,)
16661672

16671673
if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list
16681674
raise ValueError(

Diff for: tests/test_core.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def test_alignment_error():
8888

8989
@pytest.mark.parametrize("dtype", (float, int))
9090
@pytest.mark.parametrize("chunk", [False, True])
91-
@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
91+
# TODO: make this intp when python 3.8 is dropped
92+
@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2], dtype=np.int64)])
9293
@pytest.mark.parametrize(
9394
"func, array, by, expected",
9495
[
@@ -148,7 +149,12 @@ def test_groupby_reduce(
148149
)
149150
# we use pd.Index(expected_groups).to_numpy() which is always int64
150151
# for the values in this tests
151-
g_dtype = by.dtype if expected_groups is None else np.intp
152+
if expected_groups is None:
153+
g_dtype = by.dtype
154+
elif isinstance(expected_groups, np.ndarray):
155+
g_dtype = expected_groups.dtype
156+
else:
157+
g_dtype = np.int64
152158

153159
assert_equal(groups, np.array([0, 1, 2], g_dtype))
154160
assert_equal(expected_result, result)
@@ -389,12 +395,12 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
389395
kwargs["expected_groups"] = [0, 2, 1]
390396
with raise_if_dask_computes():
391397
actual, groups = groupby_reduce(array, by, engine=engine, **kwargs, sort=False)
392-
assert_equal(groups, np.array([0, 2, 1], dtype=np.intp))
398+
assert_equal(groups, np.array([0, 2, 1], dtype=np.int64))
393399
assert_equal(expected, actual[..., [0, 2, 1]])
394400

395401
with raise_if_dask_computes():
396402
actual, groups = groupby_reduce(array, by, engine=engine, **kwargs, sort=True)
397-
assert_equal(groups, np.array([0, 1, 2], np.intp))
403+
assert_equal(groups, np.array([0, 1, 2], np.int64))
398404
assert_equal(expected, actual)
399405

400406

0 commit comments

Comments
 (0)