Skip to content

fancy indexing with ints and integer arrays #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:

# Note: A large fraction of allowed indices are disallowed here (see the
# docstring below)
def _validate_index(self, key):
def _validate_index(self, key, op="getitem"):
"""
Validate an index according to the array API.

Expand Down Expand Up @@ -390,11 +390,16 @@ def _validate_index(self, key):
"zero-dimensional integer arrays and boolean arrays "
"are specified in the Array API."
)
if op == "setitem":
if isinstance(i, Array) and i.dtype in _integer_dtypes:
raise IndexError("Fancy indexing __setitem__ is not supported.")

nonexpanding_key = []
single_axes = []
n_ellipsis = 0
key_has_mask = False
key_has_index_array = False
key_has_slices = False
for i in _key:
if i is not None:
nonexpanding_key.append(i)
Expand All @@ -403,13 +408,17 @@ def _validate_index(self, key):
if isinstance(i, Array):
if i.dtype in _boolean_dtypes:
key_has_mask = True
elif i.dtype in _integer_dtypes:
key_has_index_array = True
single_axes.append(i)
else:
# i must not be an array here, to avoid elementwise equals
if i == Ellipsis:
n_ellipsis += 1
else:
single_axes.append(i)
if isinstance(i, slice):
key_has_slices = True

n_single_axes = len(single_axes)
if n_ellipsis > 1:
Expand All @@ -427,6 +436,12 @@ def _validate_index(self, key):
"specified in the Array API."
)

if (key_has_index_array and (n_ellipsis > 0 or key_has_slices or key_has_mask)):
raise IndexError(
"Integer index arrays are only allowed with integer indices; "
f"got {key}."
)

if n_ellipsis == 0:
indexed_shape = self.shape
else:
Expand Down Expand Up @@ -483,14 +498,9 @@ def _validate_index(self, key):
"Array API when the array is the sole index."
)
if not get_array_api_strict_flags()['boolean_indexing']:
raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict")

elif i.dtype in _integer_dtypes and i.ndim != 0:
raise IndexError(
f"Single-axes index {i} is a non-zero-dimensional "
"integer array, but advanced integer indexing is not "
"specified in the Array API."
)
raise RuntimeError(
"The boolean_indexing flag has been disabled for array-api-strict"
)
elif isinstance(i, tuple):
raise IndexError(
f"Single-axes index {i} is a tuple, but nested tuple "
Expand Down Expand Up @@ -902,7 +912,7 @@ def __setitem__(
"""
# Note: Only indices required by the spec are allowed. See the
# docstring of _validate_index
self._validate_index(key)
self._validate_index(key, op="setitem")
if isinstance(key, Array):
# Indexing self._array with array_api_strict arrays can be erroneous
key = key._array
Expand Down
100 changes: 75 additions & 25 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from .. import ones, asarray, result_type, all, equal
from .. import ones, arange, reshape, asarray, result_type, all, equal
from .._array_object import Array, CPU_DEVICE, Device
from .._dtypes import (
_all_dtypes,
Expand Down Expand Up @@ -45,35 +45,46 @@ def test_validate_index():
a = ones((3, 4))

# Out of bounds slices are not allowed
assert_raises(IndexError, lambda: a[:4])
assert_raises(IndexError, lambda: a[:-4])
assert_raises(IndexError, lambda: a[:3:-1])
assert_raises(IndexError, lambda: a[:-5:-1])
assert_raises(IndexError, lambda: a[4:])
assert_raises(IndexError, lambda: a[-4:])
assert_raises(IndexError, lambda: a[4::-1])
assert_raises(IndexError, lambda: a[-4::-1])

assert_raises(IndexError, lambda: a[...,:5])
assert_raises(IndexError, lambda: a[...,:-5])
assert_raises(IndexError, lambda: a[...,:5:-1])
assert_raises(IndexError, lambda: a[...,:-6:-1])
assert_raises(IndexError, lambda: a[...,5:])
assert_raises(IndexError, lambda: a[...,-5:])
assert_raises(IndexError, lambda: a[...,5::-1])
assert_raises(IndexError, lambda: a[...,-5::-1])
assert_raises(IndexError, lambda: a[:4, 0])
assert_raises(IndexError, lambda: a[:-4, 0])
assert_raises(IndexError, lambda: a[:3:-1]) # XXX raises for a wrong reason
assert_raises(IndexError, lambda: a[:-5:-1, 0])
assert_raises(IndexError, lambda: a[4:, 0])
assert_raises(IndexError, lambda: a[-4:, 0])
assert_raises(IndexError, lambda: a[4::-1, 0])
assert_raises(IndexError, lambda: a[-4::-1, 0])

assert_raises(IndexError, lambda: a[..., :5])
assert_raises(IndexError, lambda: a[..., :-5])
assert_raises(IndexError, lambda: a[..., :5:-1])
assert_raises(IndexError, lambda: a[..., :-6:-1])
assert_raises(IndexError, lambda: a[..., 5:])
assert_raises(IndexError, lambda: a[..., -5:])
assert_raises(IndexError, lambda: a[..., 5::-1])
assert_raises(IndexError, lambda: a[..., -5::-1])

# Boolean indices cannot be part of a larger tuple index
assert_raises(IndexError, lambda: a[a[:,0]==1,0])
assert_raises(IndexError, lambda: a[a[:,0]==1,...])
assert_raises(IndexError, lambda: a[..., a[0]==1])
assert_raises(IndexError, lambda: a[a[:, 0] == 1, 0])
assert_raises(IndexError, lambda: a[a[:, 0] == 1, ...])
assert_raises(IndexError, lambda: a[..., a[0] == 1])
assert_raises(IndexError, lambda: a[[True, True, True]])
assert_raises(IndexError, lambda: a[(True, True, True),])

# Integer array indices are not allowed (except for 0-D)
idx = asarray([[0, 1]])
assert_raises(IndexError, lambda: a[idx])
assert_raises(IndexError, lambda: a[idx,])
# Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed
idx = asarray([0, 1])
assert_raises(IndexError, lambda: a[..., idx])
assert_raises(IndexError, lambda: a[:, idx])
assert_raises(IndexError, lambda: a[asarray([True, True]), idx])

# 1D integer array indices must have the same length
idx1 = asarray([0, 1])
idx2 = asarray([0, 1, 1])
assert_raises(IndexError, lambda: a[idx1, idx2])

# Non-integer array indices are not allowed
assert_raises(IndexError, lambda: a[ones(2), 0])

# Array-likes (lists, tuples) are not allowed as indices
assert_raises(IndexError, lambda: a[[0, 1]])
assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
assert_raises(IndexError, lambda: a[[0, 1]])
Expand All @@ -87,6 +98,45 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[0,])
assert_raises(IndexError, lambda: a[0])
assert_raises(IndexError, lambda: a[:])
assert_raises(IndexError, lambda: a[idx])


def test_indexing_arrays():
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed

# 1D array
a = arange(5)
idx = asarray([1, 0, 1, 2, -1])
a_idx = a[idx]

a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)

# setitem with arrays is not allowed
with assert_raises(IndexError):
a[idx] = 42

# mixed array and integer indexing
a = reshape(arange(3*4), (3, 4))
idx = asarray([1, 0, 1, 2, -1])
a_idx = a[idx, 1]

a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)

# index with two arrays
a_idx = a[idx, idx]
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)

# setitem with arrays is not allowed
with assert_raises(IndexError):
a[idx, idx] = 42

# smoke test indexing with ndim > 1 arrays
idx = idx[..., None]
a[idx, idx]


def test_promoted_scalar_inherits_device():
device1 = Device("device1")
Expand Down
Loading