Skip to content

Commit d8eb4c8

Browse files
committed
Preserve input dtypes now that pandas can do it.
Closes #187
1 parent d90bf0e commit d8eb4c8

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

tests/test_core.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def test_groupby_reduce(
148148
)
149149
# we use pd.Index(expected_groups).to_numpy() which is always int64
150150
# for the values in this tests
151-
g_dtype = by.dtype if expected_groups is None else np.int64
151+
g_dtype = by.dtype if expected_groups is None else np.intp
152152

153153
assert_equal(groups, np.array([0, 1, 2], g_dtype))
154154
assert_equal(expected_result, result)
@@ -389,12 +389,12 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
389389
kwargs["expected_groups"] = [0, 2, 1]
390390
with raise_if_dask_computes():
391391
actual, groups = groupby_reduce(array, by, engine=engine, **kwargs, sort=False)
392-
assert_equal(groups, np.array([0, 2, 1], dtype=np.int64))
392+
assert_equal(groups, np.array([0, 2, 1], dtype=np.intp))
393393
assert_equal(expected, actual[..., [0, 2, 1]])
394394

395395
with raise_if_dask_computes():
396396
actual, groups = groupby_reduce(array, by, engine=engine, **kwargs, sort=True)
397-
assert_equal(groups, np.array([0, 1, 2], np.int64))
397+
assert_equal(groups, np.array([0, 1, 2], np.intp))
398398
assert_equal(expected, actual)
399399

400400

@@ -784,10 +784,8 @@ def test_dtype_preservation(dtype, func, engine):
784784

785785

786786
@requires_dask
787-
@pytest.mark.parametrize("dtype", [np.int32, np.int64])
788-
@pytest.mark.parametrize(
789-
"labels_dtype", [pytest.param(np.int32, marks=pytest.mark.xfail), np.int64]
790-
)
787+
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32, np.int64])
788+
@pytest.mark.parametrize("labels_dtype", [np.float32, np.float64, np.int32, np.int64])
791789
@pytest.mark.parametrize("method", ["map-reduce", "cohorts"])
792790
def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
793791
repeats = np.array([4, 4, 12, 2, 3, 4], dtype=np.int32)

0 commit comments

Comments
 (0)