diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 3e120e7e..55209a8c 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -242,6 +242,58 @@ def test_setitem_masking(shape, data): ) +@given(shape=hh.shapes(), data=st.data()) +def test_getitem_arrays_and_ints(shape, data): + assume((len(shape) > 0) and all(sh > 0 for sh in shape)) + + dtype = xp.int32 + obj = data.draw(scalar_objects(dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtype) + + # draw a mix of ints and index arrays + arr_index = [data.draw(st.booleans()) for _ in range(len(shape))] + assume(sum(arr_index) > 0) + + # draw shapes for index arrays + if sum(arr_index) > 0: + index_shapes = data.draw( + hh.mutually_broadcastable_shapes(sum(arr_index), min_dims=1, min_side=1) + ) + index_shapes = list(index_shapes) + + # prepare the indexing tuple, a mix of integer indices and index arrays + key = [] + for i,typ in enumerate(arr_index): + if typ: + # draw an array index + this_idx = data.draw( + xps.arrays( + dtype, + shape=index_shapes.pop(), + elements=st.integers(0, shape[i]-1) + ) + ) + key.append(this_idx) + + else: + # draw an integer + key.append(data.draw(st.integers(-shape[i], shape[i]-1))) + + + print(f"??? {x.shape = } {key = } -- {[k if isinstance(k, int) else k.shape for k in key]}") + + key = tuple(key) + out = x[key] + + # XXX: how to properly check + import numpy as np + x_np = np.asarray(x) + out_np = np.asarray(out) + key_np = tuple(k if isinstance(k, int) else np.asarray(k) for k in key) + + np.testing.assert_equal(out_np, x_np[key_np]) + + def make_scalar_casting_param( method_name: str, dtype: DataType, stype: ScalarType ) -> Param: