Skip to content

Commit 25cc3d7

Browse files
authored
Fix indexing with integers (#146)
reviewed at #146
1 parent dc71844 commit 25cc3d7

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

array_api_strict/_array_object.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def __getitem__(
722722
devices = {self.device}
723723
if isinstance(key, tuple):
724724
devices.update(
725-
[subkey.device for subkey in key if hasattr(subkey, "device")]
725+
[subkey.device for subkey in key if isinstance(subkey, Array)]
726726
)
727727
if len(devices) > 1:
728728
raise ValueError(

array_api_strict/tests/test_array_object.py

+30
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,36 @@ def test_validate_index():
100100
assert_raises(IndexError, lambda: a[:])
101101
assert_raises(IndexError, lambda: a[idx])
102102

103+
class DummyIndex:
104+
def __init__(self, x):
105+
self.x = x
106+
def __index__(self):
107+
return self.x
108+
109+
110+
@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
111+
@pytest.mark.parametrize(
112+
"integer_index",
113+
[
114+
0,
115+
np.int8(0),
116+
np.uint8(0),
117+
np.int16(0),
118+
np.uint16(0),
119+
np.int32(0),
120+
np.uint32(0),
121+
np.int64(0),
122+
np.uint64(0),
123+
DummyIndex(0),
124+
],
125+
)
126+
def test_indexing_ints(integer_index, device):
127+
# Ensure indexing with different integer types works on all Devices.
128+
device = None if device is None else Device(device)
129+
130+
a = arange(5, device=device)
131+
assert a[(integer_index,)] == a[integer_index] == a[0]
132+
103133

104134
@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
105135
def test_indexing_arrays(device):

0 commit comments

Comments
 (0)