From ba9489ab6dfcdac18f5773f063eafa29092cfa7f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 10 Feb 2025 17:29:10 +0100 Subject: [PATCH 1/4] ENH: test indexing with arrays --- array_api_tests/test_array_object.py | 57 ++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 3e120e7e..2a38df09 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -242,6 +242,63 @@ def test_setitem_masking(shape, data): ) +@given(shape=hh.shapes(), data=st.data()) +def test_getitem_arrays_and_ints(shape, data): + assume((len(shape) > 0) and all(sh > 0 for sh in shape)) + + dtype = xp.int32 + obj = data.draw(scalar_objects(dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtype) + + # draw a mix of ints and index arrays + arr_index = [data.draw(st.booleans()) for _ in range(len(shape))] + assume(sum(arr_index) > 0) + + # draw shapes for index arrays + if sum(arr_index) > 0: + index_shapes = data.draw( + hh.mutually_broadcastable_shapes(sum(arr_index), min_dims=1, min_side=1) + ) + index_shapes = list(index_shapes) + + # prepare the indexing tuple, a mix of integer indices and index arrays + key = [] + for i,typ in enumerate(arr_index): + if typ: + # draw an array index + this_idx = data.draw( + xps.arrays( + dtype, + shape=index_shapes.pop(), + elements=st.integers(0, shape[i]-1) + ) + ) + key.append(this_idx) + + else: + # draw an integer + key.append(data.draw(st.integers(-shape[i], shape[i]-1))) + + + print(f"??? {x.shape = } {key = } -- {[k if isinstance(k, int) else k.shape for k in key]}") + + key = tuple(key) + out = x[key] + + if sum(arr_index) > 1: + breakpoint() + + # XXX: how to properly check + import numpy as np + x_np = np.asarray(x) + out_np = np.asarray(out) + key_np = tuple(k if isinstance(k, int) else np.asarray(k) for k in key) + + # print(f"{x.shape = } {out.shape = } -- {[k if isinstance(k, int) else k.shape for k in key]}") + + np.testing.assert_equal(out_np, x_np[key_np]) + + def make_scalar_casting_param( method_name: str, dtype: DataType, stype: ScalarType ) -> Param: From f55b22bd22e1194ad31ac4a90ebfb2a79c08d02f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 11 Feb 2025 17:09:26 +0000 Subject: [PATCH 2/4] . --- array_api_tests/test_array_object.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 2a38df09..55209a8c 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -285,17 +285,12 @@ def test_getitem_arrays_and_ints(shape, data): key = tuple(key) out = x[key] - if sum(arr_index) > 1: - breakpoint() - # XXX: how to properly check import numpy as np x_np = np.asarray(x) out_np = np.asarray(out) key_np = tuple(k if isinstance(k, int) else np.asarray(k) for k in key) - # print(f"{x.shape = } {out.shape = } -- {[k if isinstance(k, int) else k.shape for k in key]}") - np.testing.assert_equal(out_np, x_np[key_np]) From 7105ef7f138d7745a17300b614bd3f4fc2fcc418 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 11 Feb 2025 17:12:12 +0000 Subject: [PATCH 3/4] TST: run on cupy --- array_api_tests/test_array_object.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 55209a8c..5d9916e1 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -287,9 +287,9 @@ def test_getitem_arrays_and_ints(shape, data): # XXX: how to properly check import numpy as np - x_np = np.asarray(x) - out_np = np.asarray(out) - key_np = tuple(k if isinstance(k, int) else np.asarray(k) for k in key) + x_np = np.asarray(x.get()) + out_np = np.asarray(out.get()) + key_np = tuple(k if isinstance(k, int) else np.asarray(k.get()) for k in key) np.testing.assert_equal(out_np, x_np[key_np]) From b8fcb45fdbafe78c4fad0b22071dc0fa1ea4e2d8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 11 Feb 2025 17:12:18 +0000 Subject: [PATCH 4/4] Revert "TST: run on cupy" This reverts commit 7105ef7f138d7745a17300b614bd3f4fc2fcc418. --- array_api_tests/test_array_object.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 5d9916e1..55209a8c 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -287,9 +287,9 @@ def test_getitem_arrays_and_ints(shape, data): # XXX: how to properly check import numpy as np - x_np = np.asarray(x.get()) - out_np = np.asarray(out.get()) - key_np = tuple(k if isinstance(k, int) else np.asarray(k.get()) for k in key) + x_np = np.asarray(x) + out_np = np.asarray(out) + key_np = tuple(k if isinstance(k, int) else np.asarray(k) for k in key) np.testing.assert_equal(out_np, x_np[key_np])