Skip to content

Commit c514dbc

Browse files
committed
ENH: allow 1D integer array indices
1 parent a8f9375 commit c514dbc

File tree

2 files changed

+64
-7
lines changed

2 files changed

+64
-7
lines changed

array_api_strict/_array_object.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ def _validate_index(self, key):
395395
single_axes = []
396396
n_ellipsis = 0
397397
key_has_mask = False
398+
key_has_index_array = False
399+
key_has_slices = False
398400
for i in _key:
399401
if i is not None:
400402
nonexpanding_key.append(i)
@@ -403,13 +405,17 @@ def _validate_index(self, key):
403405
if isinstance(i, Array):
404406
if i.dtype in _boolean_dtypes:
405407
key_has_mask = True
408+
elif i.dtype in _integer_dtypes:
409+
key_has_index_array = True
406410
single_axes.append(i)
407411
else:
408412
# i must not be an array here, to avoid elementwise equals
409413
if i == Ellipsis:
410414
n_ellipsis += 1
411415
else:
412416
single_axes.append(i)
417+
if isinstance(i, slice):
418+
key_has_slices = True
413419

414420
n_single_axes = len(single_axes)
415421
if n_ellipsis > 1:
@@ -427,6 +433,12 @@ def _validate_index(self, key):
427433
"specified in the Array API."
428434
)
429435

436+
if (key_has_index_array and (n_ellipsis > 0 or key_has_slices or key_has_mask)):
437+
raise IndexError(
438+
"Integer index arrays are only allowed with integer indices; "
439+
f"got {key}."
440+
)
441+
430442
if n_ellipsis == 0:
431443
indexed_shape = self.shape
432444
else:
@@ -485,11 +497,11 @@ def _validate_index(self, key):
485497
if not get_array_api_strict_flags()['boolean_indexing']:
486498
raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict")
487499

488-
elif i.dtype in _integer_dtypes and i.ndim != 0:
500+
elif i.dtype in _integer_dtypes and i.ndim > 1:
489501
raise IndexError(
490-
f"Single-axes index {i} is a non-zero-dimensional "
491-
"integer array, but advanced integer indexing is not "
492-
"specified in the Array API."
502+
f"Single-axes index {i} is a multi-dimensional "
503+
"integer array, but advanced integer indexing is only "
504+
"specified in the Array API for 1D index arrays."
493505
)
494506
elif isinstance(i, tuple):
495507
raise IndexError(

array_api_strict/tests/test_array_object.py

+48-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pytest
77

8-
from .. import ones, asarray, result_type, all, equal
8+
from .. import ones, arange, reshape, asarray, result_type, all, equal
99
from .._array_object import Array, CPU_DEVICE, Device
1010
from .._dtypes import (
1111
_all_dtypes,
@@ -70,11 +70,25 @@ def test_validate_index():
7070
assert_raises(IndexError, lambda: a[[True, True, True]])
7171
assert_raises(IndexError, lambda: a[(True, True, True),])
7272

73-
# Integer array indices are not allowed (except for 0-D)
74-
idx = asarray([0, 1])
73+
# Integer array indices are not allowed (except for 0-D or 1D)
74+
idx = asarray([[0, 1]]) # idx.ndim == 2
7575
assert_raises(IndexError, lambda: a[idx, 0])
7676
assert_raises(IndexError, lambda: a[0, idx])
7777

78+
# Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed
79+
idx = asarray([0, 1])
80+
assert_raises(IndexError, lambda: a[..., idx])
81+
assert_raises(IndexError, lambda: a[:, idx])
82+
assert_raises(IndexError, lambda: a[asarray([True, True]), idx])
83+
84+
# 1D integer array indices must have the same length
85+
idx1 = asarray([0, 1])
86+
idx2 = asarray([0, 1, 1])
87+
assert_raises(IndexError, lambda: a[idx1, idx2])
88+
89+
# Non-integer array indices are not allowed
90+
assert_raises(IndexError, lambda: a[ones(2), 0])
91+
7892
# Array-likes (lists, tuples) are not allowed as indices
7993
assert_raises(IndexError, lambda: a[[0, 1]])
8094
assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
@@ -91,6 +105,37 @@ def test_validate_index():
91105
assert_raises(IndexError, lambda: a[:])
92106
assert_raises(IndexError, lambda: a[idx])
93107

108+
109+
def test_indexing_arrays():
110+
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
111+
112+
# 1D array
113+
a = arange(5)
114+
idx = asarray([1, 0, 1, 2, -1])
115+
a_idx = a[idx]
116+
117+
a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
118+
assert all(a_idx == a_idx_loop)
119+
120+
# setitem with arrays is not allowed # XXX
121+
# with assert_raises(IndexError):
122+
# a[idx] = 42
123+
124+
# mixed array and integer indexing
125+
a = reshape(arange(3*4), (3, 4))
126+
idx = asarray([1, 0, 1, 2, -1])
127+
a_idx = a[idx, 1]
128+
129+
a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
130+
assert all(a_idx == a_idx_loop)
131+
132+
133+
# index with two arrays
134+
a_idx = a[idx, idx]
135+
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
136+
assert all(a_idx == a_idx_loop)
137+
138+
94139
def test_promoted_scalar_inherits_device():
95140
device1 = Device("device1")
96141
x = asarray([1., 2, 3], device=device1)

0 commit comments

Comments
 (0)