-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathtest_indexing_functions.py
71 lines (64 loc) · 2.11 KB
/
test_indexing_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import pytest
from hypothesis import given, note
from hypothesis import strategies as st
from . import _array_module as xp
from . import dtype_helpers as dh
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
# Module-level ``pytestmark`` applies this mark to every test in the module.
# NOTE(review): the "ci" marker is presumably used to select tests for CI
# runs; confirm its registration/meaning in conftest.py.
pytestmark = pytest.mark.ci
@pytest.mark.min_version("2022.12")
@given(
    x=hh.arrays(xps.scalar_dtypes(), hh.shapes(min_dims=1, min_side=1)),
    data=st.data(),
)
def test_take(x, data):
    """Check ``xp.take`` against element-by-element indexing of ``x``.

    Draws an axis (negative values included; per the spec they count from
    the last dimension) and a non-empty list of unique, in-bounds,
    non-negative indices along that axis. Then asserts that ``take``
    preserves the dtype, produces the expected shape, and that every element
    of the result equals the corresponding element of ``x``.
    """
    # TODO:
    # * negative indices
    # * different dtypes for indices

    # min_dims=1 above guarantees x.ndim >= 1, so this range is never empty.
    _axis_st = st.integers(-x.ndim, x.ndim - 1)
    # axis is optional but only if x.ndim == 1
    if x.ndim == 1:
        kw = data.draw(hh.kwargs(axis=_axis_st), label="kw")
    else:
        kw = {"axis": data.draw(_axis_st, label="axis")}
    # When axis is omitted (1-D case only) the effective axis is 0.
    axis = kw.get("axis", 0)
    # Non-negative equivalent of axis, used for all shape/index bookkeeping
    # below (the shape helpers may not accept negative axes; confirm).
    n_axis = axis if axis >= 0 else x.ndim + axis

    _indices = data.draw(
        st.lists(st.integers(0, x.shape[n_axis] - 1), min_size=1, unique=True),
        label="_indices",
    )
    indices = xp.asarray(_indices, dtype=dh.default_int)
    note(f"{indices=}")

    out = xp.take(x, indices, **kw)

    ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
    ph.assert_shape(
        "take",
        out_shape=out.shape,
        expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1 :],
        kw=dict(
            x=x,
            indices=indices,
            axis=axis,
        ),
    )

    # Walk ``out`` in lock-step with ``x``. The nesting below (leading dims,
    # then each taken index, then trailing dims) assumes sh.ndindex yields
    # indices in C (row-major) order, and that sh.axis_ndindex(shape, axis)
    # yields index tuples fixing the dims before ``axis`` and slicing the
    # rest. Confirm against shape_helpers.
    out_indices = sh.ndindex(out.shape)
    for axis_idx in sh.axis_ndindex(x.shape, n_axis):
        f_axis_idx = sh.fmt_idx("x", axis_idx)
        for i in _indices:
            f_take_idx = sh.fmt_idx(f_axis_idx, i)
            # x[axis_idx] leaves dims n_axis onward; [i, ...] then selects
            # position i along the taken axis.
            indexed_x = x[axis_idx][i, ...]
            for at_idx in sh.ndindex(indexed_x.shape):
                out_idx = next(out_indices)
                ph.assert_0d_equals(
                    "take",
                    x_repr=sh.fmt_idx(f_take_idx, at_idx),
                    x_val=indexed_x[at_idx],
                    out_repr=sh.fmt_idx("out", out_idx),
                    out_val=out[out_idx],
                )
    # sanity check: every element of out was compared exactly once
    with pytest.raises(StopIteration):
        next(out_indices)