Skip to content

Commit 9952ed1

Browse files
committed
add device check
1 parent 4449349 commit 9952ed1

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

array_api_strict/_array_object.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,14 @@ def __getitem__(
727727
if isinstance(key, Array):
728728
key = (key,)
729729
np_key = key
730+
devices = {self.device}
730731
if isinstance(key, tuple):
732+
devices.update([subkey.device for subkey in key])
733+
if len(devices) > 1:
734+
raise ValueError(
735+
"Array indexing is only allowed when array to be indexed and all "
736+
"indexing arrays are on the same device."
737+
)
731738
# Indexing self._array with array_api_strict arrays can be erroneous
732739
# e.g., when using non-default device
733740
np_key = tuple(

array_api_strict/tests/test_array_object.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,14 @@ def test_indexing_arrays(device):
107107
device = None if device is None else Device(device)
108108

109109
# 1D array
110-
a = arange(5)
110+
a = arange(5, device=device)
111111
idx = asarray([1, 0, 1, 2, -1], device=device)
112112
a_idx = a[idx]
113113

114114
a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])])
115115
assert all(a_idx == a_idx_loop)
116116
assert a_idx.shape == idx.shape
117+
assert a.device == idx.device == a_idx.device
117118

118119
# setitem with arrays is not allowed
119120
with assert_raises(IndexError):
@@ -126,20 +127,39 @@ def test_indexing_arrays(device):
126127
a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])])
127128
assert all(a_idx == a_idx_loop)
128129
assert a_idx.shape == idx.shape
130+
assert a.device == idx.device == a_idx.device
129131

130132
# index with two arrays
131133
a_idx = a[idx, idx]
132134
a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])])
133135
assert all(a_idx == a_idx_loop)
134136
assert a_idx.shape == a_idx.shape
137+
assert a.device == idx.device == a_idx.device
135138

136139
# setitem with arrays is not allowed
137140
with assert_raises(IndexError):
138141
a[idx, idx] = 42
139142

140143
# smoke test indexing with ndim > 1 arrays
141144
idx = idx[..., None]
142-
a[idx, idx]
145+
a_idx = a[idx, idx]
146+
assert a.device == idx.device == a_idx.device
147+
148+
149+
def test_indexing_arrays_different_devices():
150+
# Ensure indexing via array on different device errors
151+
device1 = Device("CPU_DEVICE")
152+
device2 = Device("device1")
153+
154+
a = arange(5, device=device1)
155+
idx1 = asarray([1, 0, 1, 2, -1], device=device2)
156+
idx2 = asarray([1, 0, 1, 2, -1], device=device1)
157+
158+
with pytest.raises(ValueError, match="Array indexing is only allowed when"):
159+
a[idx1]
160+
161+
with pytest.raises(ValueError, match="Array indexing is only allowed when"):
162+
a[idx1, idx2]
143163

144164

145165
def test_promoted_scalar_inherits_device():

0 commit comments

Comments
 (0)