Skip to content

Commit ca44b2a

Browse files
asmeurerhonno
authored andcommitted
Fix test_take to make axis optional when ndim == 1
I didn't explicitly test axis=None because it's not clear to me that should actually be supported, given that that's the same as axis=0.
1 parent 6a1f943 commit ca44b2a

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

array_api_tests/test_indexing_functions.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,22 @@ def test_take(x, data):
2222
# * negative axis
2323
# * negative indices
2424
# * different dtypes for indices
25-
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
25+
26+
# axis is optional but only if x.ndim == 1
27+
_axis_st = st.integers(0, max(x.ndim - 1, 0))
28+
if x.ndim == 1:
29+
kw = data.draw(hh.kwargs(axis=_axis_st))
30+
else:
31+
kw = {"axis": data.draw(_axis_st)}
32+
axis = kw.get("axis", 0)
2633
_indices = data.draw(
2734
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
2835
label="_indices",
2936
)
3037
indices = xp.asarray(_indices, dtype=dh.default_int)
3138
note(f"{indices=}")
3239

33-
out = xp.take(x, indices, axis=axis)
40+
out = xp.take(x, indices, **kw)
3441

3542
ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
3643
ph.assert_shape(

0 commit comments

Comments
 (0)