Skip to content

Commit 64d3d67

Browse files
authored
indexing tests (#14)
* check the indexing behavior for scalars * refactor the indexers strategy * add a test for slices * comment on future refactoring work * configure coverage * refactor to also allow integer arrays
1 parent 4b4c993 commit 64d3d67

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

pyproject.toml

+8
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,11 @@ known-third-party = []
6363

6464
[tool.ruff.lint.flake8-tidy-imports]
6565
ban-relative-imports = "all"
66+
67+
[tool.coverage.run]
68+
source = ["xarray_array_testing"]
69+
branch = true
70+
71+
[tool.coverage.report]
72+
show_missing = true
73+
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]

xarray_array_testing/indexing.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from contextlib import nullcontext
2+
3+
import hypothesis.extra.numpy as npst
4+
import hypothesis.strategies as st
5+
import xarray.testing.strategies as xrst
6+
from hypothesis import given
7+
8+
from xarray_array_testing.base import DuckArrayTestMixin
9+
10+
11+
def scalar_indexer(size):
12+
return st.integers(min_value=-size, max_value=size - 1)
13+
14+
15+
def integer_array_indexer(size):
16+
dtypes = npst.integer_dtypes()
17+
18+
return npst.arrays(
19+
dtypes, size, elements={"min_value": -size, "max_value": size - 1}
20+
)
21+
22+
23+
def indexers(size, indexer_types):
24+
indexer_strategy_fns = {
25+
"scalars": scalar_indexer,
26+
"slices": st.slices,
27+
"integer_arrays": integer_array_indexer,
28+
}
29+
30+
bad_types = set(indexer_types) - indexer_strategy_fns.keys()
31+
if bad_types:
32+
raise ValueError(f"unknown indexer strategies: {sorted(bad_types)}")
33+
34+
# use the order of definition to prefer simpler strategies over more complex
35+
# ones
36+
indexer_strategies = [
37+
strategy_fn(size)
38+
for name, strategy_fn in indexer_strategy_fns.items()
39+
if name in indexer_types
40+
]
41+
return st.one_of(*indexer_strategies)
42+
43+
44+
@st.composite
45+
def orthogonal_indexers(draw, sizes, indexer_types):
46+
# TODO: make use of `flatmap` and `builds` instead of `composite`
47+
possible_indexers = {
48+
dim: indexers(size, indexer_types) for dim, size in sizes.items()
49+
}
50+
concrete_indexers = draw(xrst.unique_subset_of(possible_indexers))
51+
return {dim: draw(indexer) for dim, indexer in concrete_indexers.items()}
52+
53+
54+
class IndexingTests(DuckArrayTestMixin):
55+
@property
56+
def orthogonal_indexer_types(self):
57+
return st.sampled_from(["scalars", "slices"])
58+
59+
@staticmethod
60+
def expected_errors(op, **parameters):
61+
return nullcontext()
62+
63+
@given(st.data())
64+
def test_variable_isel_orthogonal(self, data):
65+
indexer_types = data.draw(
66+
st.lists(self.orthogonal_indexer_types, min_size=1, unique=True)
67+
)
68+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
69+
idx = data.draw(orthogonal_indexers(variable.sizes, indexer_types))
70+
71+
with self.expected_errors(
72+
"isel_orthogonal", variable=variable, indexer_types=indexer_types
73+
):
74+
actual = variable.isel(idx).data
75+
76+
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
77+
expected = variable.data[*raw_indexers.values()]
78+
79+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
80+
self.assert_equal(actual, expected)

xarray_array_testing/tests/test_numpy.py

+5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from xarray_array_testing.base import DuckArrayTestMixin
77
from xarray_array_testing.creation import CreationTests
8+
from xarray_array_testing.indexing import IndexingTests
89
from xarray_array_testing.reduction import ReductionTests
910

1011

@@ -32,3 +33,7 @@ class TestCreationNumpy(CreationTests, NumpyTestMixin):
3233

3334
class TestReductionNumpy(ReductionTests, NumpyTestMixin):
3435
pass
36+
37+
38+
class TestIndexingNumpy(IndexingTests, NumpyTestMixin):
39+
pass

0 commit comments

Comments
 (0)