Skip to content

Commit 99f3f3b

Browse files
committed
refactor to also allow integer arrays
1 parent 3e555d7 commit 99f3f3b

File tree

1 file changed

+48
-21
lines changed

1 file changed

+48
-21
lines changed

xarray_array_testing/indexing.py

+48-21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from contextlib import nullcontext
22

3+
import hypothesis.extra.numpy as npst
34
import hypothesis.strategies as st
45
import xarray.testing.strategies as xrst
56
from hypothesis import given
@@ -11,39 +12,65 @@ def scalar_indexer(size):
1112
return st.integers(min_value=-size, max_value=size - 1)
1213

1314

15+
def integer_array_indexer(size):
16+
dtypes = npst.integer_dtypes()
17+
18+
return npst.arrays(
19+
dtypes, size, elements={"min_value": -size, "max_value": size - 1}
20+
)
21+
22+
23+
def indexers(size, indexer_types):
24+
indexer_strategy_fns = {
25+
"scalars": scalar_indexer,
26+
"slices": st.slices,
27+
"integer_arrays": integer_array_indexer,
28+
}
29+
30+
bad_types = set(indexer_types) - indexer_strategy_fns.keys()
31+
if bad_types:
32+
raise ValueError(f"unknown indexer strategies: {sorted(bad_types)}")
33+
34+
# use the order of definition to prefer simpler strategies over more complex
35+
# ones
36+
indexer_strategies = [
37+
strategy_fn(size)
38+
for name, strategy_fn in indexer_strategy_fns.items()
39+
if name in indexer_types
40+
]
41+
return st.one_of(*indexer_strategies)
42+
43+
1444
@st.composite
15-
def indexers(draw, sizes, indexer_strategy_fn):
45+
def orthogonal_indexers(draw, sizes, indexer_types):
1646
# TODO: make use of `flatmap` and `builds` instead of `composite`
17-
possible_indexers = {dim: indexer_strategy_fn(size) for dim, size in sizes.items()}
18-
indexers = draw(xrst.unique_subset_of(possible_indexers))
19-
return {dim: draw(indexer) for dim, indexer in indexers.items()}
47+
possible_indexers = {
48+
dim: indexers(size, indexer_types) for dim, size in sizes.items()
49+
}
50+
concrete_indexers = draw(xrst.unique_subset_of(possible_indexers))
51+
return {dim: draw(indexer) for dim, indexer in concrete_indexers.items()}
2052

2153

2254
class IndexingTests(DuckArrayTestMixin):
55+
@property
56+
def orthogonal_indexer_types(self):
57+
return st.sampled_from(["scalars", "slices"])
58+
2359
@staticmethod
2460
def expected_errors(op, **parameters):
2561
return nullcontext()
2662

2763
@given(st.data())
28-
def test_variable_isel_scalars(self, data):
29-
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
30-
idx = data.draw(indexers(variable.sizes, scalar_indexer))
31-
32-
with self.expected_errors("isel_scalars", variable=variable):
33-
actual = variable.isel(idx).data
34-
35-
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
36-
expected = variable.data[*raw_indexers.values()]
37-
38-
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
39-
self.assert_equal(actual, expected)
40-
41-
@given(st.data())
42-
def test_variable_isel_slices(self, data):
64+
def test_variable_isel_orthogonal(self, data):
65+
indexer_types = data.draw(
66+
st.lists(self.orthogonal_indexer_types, min_size=1, unique=True)
67+
)
4368
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
44-
idx = data.draw(indexers(variable.sizes, st.slices))
69+
idx = data.draw(orthogonal_indexers(variable.sizes, indexer_types))
4570

46-
with self.expected_errors("isel_slices", variable=variable):
71+
with self.expected_errors(
72+
"isel_orthogonal", variable=variable, indexer_types=indexer_types
73+
):
4774
actual = variable.isel(idx).data
4875

4976
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}

0 commit comments

Comments
 (0)