Skip to content

Commit 36a370a

Browse files
committed
ENH: fancy indexing __setitem__ is not allowed
1 parent c514dbc commit 36a370a

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

array_api_strict/_array_object.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
327327

328328
# Note: A large fraction of allowed indices are disallowed here (see the
329329
# docstring below)
330-
def _validate_index(self, key):
330+
def _validate_index(self, key, op="getitem"):
331331
"""
332332
Validate an index according to the array API.
333333
@@ -390,6 +390,9 @@ def _validate_index(self, key):
390390
"zero-dimensional integer arrays and boolean arrays "
391391
"are specified in the Array API."
392392
)
393+
if op == "setitem":
394+
if isinstance(i, Array) and i.dtype in _integer_dtypes:
395+
raise IndexError("Fancy indexing __setitem__ is not supported.")
393396

394397
nonexpanding_key = []
395398
single_axes = []
@@ -914,7 +917,7 @@ def __setitem__(
914917
"""
915918
# Note: Only indices required by the spec are allowed. See the
916919
# docstring of _validate_index
917-
self._validate_index(key)
920+
self._validate_index(key, op="setitem")
918921
if isinstance(key, Array):
919922
# Indexing self._array with array_api_strict arrays can be erroneous
920923
key = key._array

array_api_strict/tests/test_array_object.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ def test_indexing_arrays():
117117
a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
118118
assert all(a_idx == a_idx_loop)
119119

120-
# setitem with arrays is not allowed # XXX
121-
# with assert_raises(IndexError):
122-
# a[idx] = 42
120+
# setitem with arrays is not allowed
121+
with assert_raises(IndexError):
122+
a[idx] = 42
123123

124124
# mixed array and integer indexing
125125
a = reshape(arange(3*4), (3, 4))
@@ -129,12 +129,15 @@ def test_indexing_arrays():
129129
a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
130130
assert all(a_idx == a_idx_loop)
131131

132-
133132
# index with two arrays
134133
a_idx = a[idx, idx]
135134
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
136135
assert all(a_idx == a_idx_loop)
137136

137+
# setitem with arrays is not allowed
138+
with assert_raises(IndexError):
139+
a[idx, idx] = 42
140+
138141

139142
def test_promoted_scalar_inherits_device():
140143
device1 = Device("device1")

0 commit comments

Comments
 (0)