Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fix tuple array indexing #139

Merged
merged 6 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,8 +724,24 @@ def __getitem__(
# Note: Only indices required by the spec are allowed. See the
# docstring of _validate_index
self._validate_index(key, op="getitem")
# Indexing self._array with array_api_strict arrays can be erroneous
np_key = key._array if isinstance(key, Array) else key
if isinstance(key, Array):
key = (key,)
np_key = key
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new naming key -> np_key, meant that I've had to add this line

devices = {self.device}
if isinstance(key, tuple):
devices.update(
[subkey.device for subkey in key if hasattr(subkey, "device")]
)
if len(devices) > 1:
raise ValueError(
"Array indexing is only allowed when array to be indexed and all "
"indexing arrays are on the same device."
)
# Indexing self._array with array_api_strict arrays can be erroneous
# e.g., when using non-default device
np_key = tuple(
subkey._array if isinstance(subkey, Array) else subkey for subkey in key
)
res = self._array.__getitem__(np_key)
return self._new(res, device=self.device)

Expand Down
46 changes: 35 additions & 11 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from .. import ones, arange, reshape, asarray, result_type, all, equal
from .. import ones, arange, reshape, asarray, result_type, all, equal, stack
from .._array_object import Array, CPU_DEVICE, Device
from .._dtypes import (
_all_dtypes,
Expand Down Expand Up @@ -101,41 +101,65 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[idx])


def test_indexing_arrays():
@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
def test_indexing_arrays(device):
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
device = None if device is None else Device(device)

# 1D array
a = arange(5)
idx = asarray([1, 0, 1, 2, -1])
a = arange(5, device=device)
idx = asarray([1, 0, 1, 2, -1], device=device)
a_idx = a[idx]

a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == idx.shape
assert a.device == idx.device == a_idx.device

# setitem with arrays is not allowed
with assert_raises(IndexError):
a[idx] = 42

# mixed array and integer indexing
a = reshape(arange(3*4), (3, 4))
idx = asarray([1, 0, 1, 2, -1])
a = reshape(arange(3*4, device=device), (3, 4))
idx = asarray([1, 0, 1, 2, -1], device=device)
a_idx = a[idx, 1]

a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == idx.shape
assert a.device == idx.device == a_idx.device

# index with two arrays
a_idx = a[idx, idx]
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == a_idx.shape
assert a.device == idx.device == a_idx.device

# setitem with arrays is not allowed
with assert_raises(IndexError):
a[idx, idx] = 42

# smoke test indexing with ndim > 1 arrays
idx = idx[..., None]
a[idx, idx]
a_idx = a[idx, idx]
assert a.device == idx.device == a_idx.device


def test_indexing_arrays_different_devices():
# Ensure indexing via array on different device errors
device1 = Device("CPU_DEVICE")
device2 = Device("device1")

a = arange(5, device=device1)
idx1 = asarray([1, 0, 1, 2, -1], device=device2)
idx2 = asarray([1, 0, 1, 2, -1], device=device1)

with pytest.raises(ValueError, match="Array indexing is only allowed when"):
a[idx1]

with pytest.raises(ValueError, match="Array indexing is only allowed when"):
a[idx1, idx2]


def test_promoted_scalar_inherits_device():
Expand Down
Loading