From 001ace3962aabcdf312ae7344ec6fce5104addbf Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 31 Mar 2025 22:53:43 +1100 Subject: [PATCH 1/5] fix --- array_api_strict/_array_object.py | 7 ++++++- array_api_strict/tests/test_array_object.py | 22 ++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 0595594..91eb82c 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -698,8 +698,13 @@ def __getitem__( # docstring of _validate_index self._validate_index(key) if isinstance(key, Array): + key = (key,) + if isinstance(key, tuple): # Indexing self._array with array_api_strict arrays can be erroneous - key = key._array + # e.g., when using non-default device + key = tuple( + subkey._array if isinstance(subkey, Array) else subkey for subkey in key + ) res = self._array.__getitem__(key) return self._new(res, device=self.device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index e24a40f..f017261 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from .. import ones, arange, reshape, asarray, result_type, all, equal +from .. import ones, arange, reshape, asarray, result_type, all, equal, stack from .._array_object import Array, CPU_DEVICE, Device from .._dtypes import ( _all_dtypes, @@ -101,33 +101,37 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[idx]) -def test_indexing_arrays(): +# @pytest.mark.parametrize("device", ["CPU_DEVICE", "device1", "device2"]) +def test_indexing_arrays(device='device1'): # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed + device = Device(device) # 1D array a = arange(5) - idx = asarray([1, 0, 1, 2, -1]) + idx = asarray([1, 0, 1, 2, -1], device=device) a_idx = a[idx] - a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == idx.shape # setitem with arrays is not allowed with assert_raises(IndexError): a[idx] = 42 # mixed array and integer indexing - a = reshape(arange(3*4), (3, 4)) - idx = asarray([1, 0, 1, 2, -1]) + a = reshape(arange(3*4, device=device), (3, 4)) + idx = asarray([1, 0, 1, 2, -1], device=device) a_idx = a[idx, 1] - - a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == idx.shape # index with two arrays a_idx = a[idx, idx] - a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == a_idx.shape # setitem with arrays is not allowed with assert_raises(IndexError): From 9c887e8610be1493213427abd56435c0f47a1d04 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 2 Apr 2025 12:13:23 +1100 Subject: [PATCH 2/5] fix param --- array_api_strict/tests/test_array_object.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index f017261..897ef14 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -101,10 +101,10 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[idx]) -# @pytest.mark.parametrize("device", ["CPU_DEVICE", "device1", "device2"]) -def test_indexing_arrays(device='device1'): +@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) +def test_indexing_arrays(): # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed - device = Device(device) + device = None if device is None else Device(device) # 1D array a = arange(5) From 4449349b423e738b9de8bea8145b27bee44b1911 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 2 Apr 2025 12:25:59 +1100 Subject: [PATCH 3/5] fix typos --- array_api_strict/_array_object.py | 2 +- array_api_strict/tests/test_array_object.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index b0d6a49..9edbd77 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -733,7 +733,7 @@ def __getitem__( np_key = tuple( subkey._array if isinstance(subkey, Array) else subkey for subkey in key ) - res = self._array.__getitem__(key) + res = self._array.__getitem__(np_key) return self._new(res, device=self.device) def __gt__(self, other: Array | int | float, /) -> Array: diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 897ef14..a233e44 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -102,7 +102,7 @@ def test_validate_index(): @pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) -def test_indexing_arrays(): +def test_indexing_arrays(device): # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed device = None if device is None else Device(device) From 9952ed11cfc040d54bac15a07e40213436233409 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 2 Apr 2025 14:30:10 +1100 Subject: [PATCH 4/5] add device check --- array_api_strict/_array_object.py | 7 ++++++ array_api_strict/tests/test_array_object.py | 24 +++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 9edbd77..05e292c 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -727,7 +727,14 @@ def __getitem__( if isinstance(key, Array): key = (key,) np_key = key + devices = {self.device} if isinstance(key, tuple): + devices.update([subkey.device for subkey in key]) + if len(devices) > 1: + raise ValueError( + "Array indexing is only allowed when array to be indexed and all " + "indexing arrays are on the same device." + ) # Indexing self._array with array_api_strict arrays can be erroneous # e.g., when using non-default device np_key = tuple( diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a233e44..51f4f31 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -107,13 +107,14 @@ def test_indexing_arrays(device): device = None if device is None else Device(device) # 1D array - a = arange(5) + a = arange(5, device=device) idx = asarray([1, 0, 1, 2, -1], device=device) a_idx = a[idx] a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) assert a_idx.shape == idx.shape + assert a.device == idx.device == a_idx.device # setitem with arrays is not allowed with assert_raises(IndexError): @@ -126,12 +127,14 @@ def test_indexing_arrays(device): a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) assert a_idx.shape == idx.shape + assert a.device == idx.device == a_idx.device # index with two arrays a_idx = a[idx, idx] a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) assert a_idx.shape == a_idx.shape + assert a.device == idx.device == a_idx.device # setitem with arrays is not allowed with assert_raises(IndexError): @@ -139,7 +142,24 @@ def test_indexing_arrays(device): # smoke test indexing with ndim > 1 arrays idx = idx[..., None] - a[idx, idx] + a_idx = a[idx, idx] + assert a.device == idx.device == a_idx.device + + +def test_indexing_arrays_different_devices(): + # Ensure indexing via array on different device errors + device1 = Device("CPU_DEVICE") + device2 = Device("device1") + + a = arange(5, device=device1) + idx1 = asarray([1, 0, 1, 2, -1], device=device2) + idx2 = asarray([1, 0, 1, 2, -1], device=device1) + + with pytest.raises(ValueError, match="Array indexing is only allowed when"): + a[idx1] + + with pytest.raises(ValueError, match="Array indexing is only allowed when"): + a[idx1, idx2] def test_promoted_scalar_inherits_device(): From 65e630a548846cbe50cb7ae3d5704a42096d7e6b Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 2 Apr 2025 14:34:07 +1100 Subject: [PATCH 5/5] fix --- array_api_strict/_array_object.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 05e292c..823274e 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -729,7 +729,9 @@ def __getitem__( np_key = key devices = {self.device} if isinstance(key, tuple): - devices.update([subkey.device for subkey in key]) + devices.update( + [subkey.device for subkey in key if hasattr(subkey, "device")] + ) if len(devices) > 1: raise ValueError( "Array indexing is only allowed when array to be indexed and all "