diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 5102ccbdcb..e0a8145056 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -771,6 +771,7 @@ jobs: cd /home/runner/work/array-api-tests ${CONDA_PREFIX}/bin/python -c "import dpctl; dpctl.lsplatform()" export ARRAY_API_TESTS_MODULE=dpctl.tensor + export ARRAY_API_TESTS_VERSION=2024.12 ${CONDA_PREFIX}/bin/python -m pytest --json-report --json-report-file=$FILE --disable-deadline --skips-file ${GITHUB_WORKSPACE}/.github/workflows/array-api-skips.txt array_api_tests/ || true - name: Set Github environment variables shell: bash -l {0} diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index a3e925bc1c..a2189defe7 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import builtins import operator +from numbers import Integral import numpy as np @@ -799,6 +800,79 @@ def _nonzero_impl(ary): return res +def _validate_indices(inds, queue_list, usm_type_list): + """ + Utility for validating indices are usm_ndarray of integral dtype or Python + integers. At least one must be an array. + + For each array, the queue and usm type are appended to `queue_list` and + `usm_type_list`, respectively. + """ + any_usmarray = False + for ind in inds: + if isinstance(ind, dpt.usm_ndarray): + any_usmarray = True + if ind.dtype.kind not in "ui": + raise IndexError( + "arrays used as indices must be of integer (or boolean) " + "type" + ) + queue_list.append(ind.sycl_queue) + usm_type_list.append(ind.usm_type) + elif not isinstance(ind, Integral): + raise TypeError( + "all elements of `ind` expected to be usm_ndarrays " + f"or integers, found {type(ind)}" + ) + if not any_usmarray: + raise TypeError( + "at least one element of `inds` expected to be a usm_ndarray" + ) + return inds + + +def _prepare_indices_arrays(inds, q, usm_type): + """ + Utility taking a mix of usm_ndarray and possibly Python int scalar indices, + a queue (assumed to be common to arrays in inds), and a usm type. + + Python scalar integers are promoted to arrays on the provided queue and + with the provided usm type. All arrays are then promoted to a common + integral type (if possible) before being broadcast to a common shape. 
+ """ + # scalar integers -> arrays + inds = tuple( + map( + lambda ind: ( + ind + if isinstance(ind, dpt.usm_ndarray) + else dpt.asarray(ind, usm_type=usm_type, sycl_queue=q) + ), + inds, + ) + ) + + # promote to a common integral type if possible + ind_dt = dpt.result_type(*inds) + if ind_dt.kind not in "ui": + raise ValueError( + "cannot safely promote indices to an integer data type" + ) + inds = tuple( + map( + lambda ind: ( + ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt) + ), + inds, + ) + ) + + # broadcast + inds = dpt.broadcast_arrays(*inds) + + return inds + + def _take_multi_index(ary, inds, p, mode=0): if not isinstance(ary, dpt.usm_ndarray): raise TypeError( @@ -819,15 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0): ] if not isinstance(inds, (list, tuple)): inds = (inds,) - for ind in inds: - if not isinstance(ind, dpt.usm_ndarray): - raise TypeError("all elements of `ind` expected to be usm_ndarrays") - queues_.append(ind.sycl_queue) - usm_types_.append(ind.usm_type) - if ind.dtype.kind not in "ui": - raise IndexError( - "arrays used as indices must be of integer (or boolean) type" - ) + + _validate_indices(inds, queues_, usm_types_) res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) exec_q = dpctl.utils.get_execution_queue(queues_) if exec_q is None: @@ -837,22 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0): "Use `usm_ndarray.to_device` method to migrate data to " "be associated with the same queue." ) + if len(inds) > 1: - ind_dt = dpt.result_type(*inds) - # ind arrays have been checked to be of integer dtype - if ind_dt.kind not in "ui": - raise ValueError( - "cannot safely promote indices to an integer data type" - ) - inds = tuple( - map( - lambda ind: ( - ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt) - ), - inds, - ) - ) - inds = dpt.broadcast_arrays(*inds) + inds = _prepare_indices_arrays(inds, exec_q, res_usm_type) + ind0 = inds[0] ary_sh = ary.shape p_end = p + len(inds) @@ -968,15 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0): ] if not isinstance(inds, (list, tuple)): inds = (inds,) - for ind in inds: - if not isinstance(ind, dpt.usm_ndarray): - raise TypeError("all elements of `ind` expected to be usm_ndarrays") - queues_.append(ind.sycl_queue) - usm_types_.append(ind.usm_type) - if ind.dtype.kind not in "ui": - raise IndexError( - "arrays used as indices must be of integer (or boolean) type" - ) + + _validate_indices(inds, queues_, usm_types_) + vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) exec_q = dpctl.utils.get_execution_queue(queues_) if exec_q is not None: @@ -993,22 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0): "Use `usm_ndarray.to_device` method to migrate data to " "be associated with the same queue." ) + if len(inds) > 1: - ind_dt = dpt.result_type(*inds) - # ind arrays have been checked to be of integer dtype - if ind_dt.kind not in "ui": - raise ValueError( - "cannot safely promote indices to an integer data type" - ) - inds = tuple( - map( - lambda ind: ( - ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt) - ), - inds, - ) - ) - inds = dpt.broadcast_arrays(*inds) + inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type) + ind0 = inds[0] ary_sh = ary.shape p_end = p + len(inds) diff --git a/dpctl/tensor/_slicing.pxi b/dpctl/tensor/_slicing.pxi index 240e42a61d..bc7b13f7c7 100644 --- a/dpctl/tensor/_slicing.pxi +++ b/dpctl/tensor/_slicing.pxi @@ -15,6 +15,7 @@ # limitations under the License. 
import numbers +from operator import index from cpython.buffer cimport PyObject_CheckBuffer @@ -64,7 +65,7 @@ cdef bint _is_integral(object x) except *: return False if callable(getattr(x, "__index__", None)): try: - x.__index__() + index(x) except (TypeError, ValueError): return False return True @@ -136,7 +137,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): else: return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos) elif _is_integral(ind): - ind = ind.__index__() + ind = index(ind) new_shape = shape[1:] new_strides = strides[1:] is_empty = any(sh_i == 0 for sh_i in new_shape) @@ -179,10 +180,12 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): if array_streak_started: array_streak_interrupted = True elif _is_integral(i): - explicit_index += 1 axes_referenced += 1 - if array_streak_started: - array_streak_interrupted = True + if array_streak_started and not array_streak_interrupted: + # integers converted to arrays in this case + array_count += 1 + else: + explicit_index += 1 elif isinstance(i, usm_ndarray): if not seen_arrays_yet: seen_arrays_yet = True @@ -229,6 +232,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): advanced_start_pos_set = False new_offset = offset is_empty = False + array_streak = False for i in range(len(ind)): ind_i = ind[i] if (ind_i is Ellipsis): @@ -239,9 +243,13 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): is_empty = True new_offset = offset k = k_new + if array_streak: + array_streak = False elif ind_i is None: new_shape.append(1) new_strides.append(0) + if array_streak: + array_streak = False elif isinstance(ind_i, slice): k_new = k + 1 sl_start, sl_stop, sl_step = ind_i.indices(shape[k]) @@ -255,26 +263,46 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): is_empty = True new_offset = offset k = k_new + if array_streak: + array_streak = False elif _is_boolean(ind_i): new_shape.append(1 if ind_i else 0) new_strides.append(0) + if array_streak: + array_streak = False elif _is_integral(ind_i): - ind_i = ind_i.__index__() - if 0 <= ind_i < shape[k]: + if array_streak: + if not isinstance(ind_i, usm_ndarray): + ind_i = index(ind_i) + # integer will be converted to an array, still raise if OOB + if not (0 <= ind_i < shape[k] or -shape[k] <= ind_i < 0): + raise IndexError( + ("Index {0} is out of range for " + "axes {1} with size {2}").format(ind_i, k, shape[k])) + new_advanced_ind.append(ind_i) k_new = k + 1 - if not is_empty: - new_offset = new_offset + ind_i * strides[k] - k = k_new - elif -shape[k] <= ind_i < 0: - k_new = k + 1 - if not is_empty: - new_offset = new_offset + (shape[k] + ind_i) * strides[k] + new_shape.extend(shape[k:k_new]) + new_strides.extend(strides[k:k_new]) k = k_new else: - raise IndexError( - ("Index {0} is out of range for " - "axes {1} with size {2}").format(ind_i, k, shape[k])) + ind_i = index(ind_i) + if 0 <= ind_i < shape[k]: + k_new = k + 1 + if not is_empty: + new_offset = new_offset + ind_i * strides[k] + k = k_new + elif -shape[k] <= ind_i < 0: + k_new = k + 1 + if not is_empty: + new_offset = new_offset + (shape[k] + ind_i) * strides[k] + k = k_new + else: + raise IndexError( + ("Index {0} is out of range for " + "axes {1} with size {2}").format(ind_i, k, shape[k])) elif isinstance(ind_i, usm_ndarray): + if not array_streak: + array_streak = True if not advanced_start_pos_set: new_advanced_start_pos = len(new_shape) advanced_start_pos_set = True @@ -287,8 
+315,6 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): new_shape.extend(shape[k:k_new]) new_strides.extend(strides[k:k_new]) k = k_new - else: - raise IndexError new_shape.extend(shape[k:]) new_strides.extend(strides[k:]) new_shape_len += len(shape) - k diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 91f14660fb..a375bf93fe 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -161,7 +161,6 @@ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue) ev = self_queue.submit_barrier() stream.submit_barrier(dependent_events=[ev]) - cdef class usm_ndarray: """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \ offset=0, order="C", buffer_ctor_kwargs=dict(), \ @@ -962,6 +961,8 @@ cdef class usm_ndarray: return res from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index + + # if len(adv_ind) == 1, the (only) element is always an array if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool: key_ = adv_ind[0] adv_ind_end_p = key_.ndim + adv_ind_start_p @@ -979,10 +980,10 @@ cdef class usm_ndarray: res.flags_ = _copy_writable(res.flags_, self.flags_) return res - if any(ind.dtype == dpt_bool for ind in adv_ind): + if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind): adv_ind_int = list() for ind in adv_ind: - if ind.dtype == dpt_bool: + if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool: adv_ind_int.extend(_nonzero_impl(ind)) else: adv_ind_int.append(ind) @@ -1433,10 +1434,10 @@ cdef class usm_ndarray: _place_impl(Xv, adv_ind[0], rhs, axis=adv_ind_start_p) return - if any(ind.dtype == dpt_bool for ind in adv_ind): + if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind): adv_ind_int = list() for ind in adv_ind: - if ind.dtype == dpt_bool: + if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool: adv_ind_int.extend(_nonzero_impl(ind)) else: adv_ind_int.append(ind) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 78501580a8..eeb97461fd 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -22,6 +22,7 @@ import dpctl import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti +from dpctl.tensor._copy_utils import _take_multi_index from dpctl.utils import ExecutionPlacementError from .helper import get_queue_or_skip, skip_if_dtype_not_supported @@ -252,8 +253,14 @@ def test_advanced_slice5(): q = get_queue_or_skip() ii = dpt.asarray([1, 2], sycl_queue=q) x = _make_3d("i4", q) - with pytest.raises(IndexError): - x[ii, 0, ii] + y = x[ii, 0, ii] + assert isinstance(y, dpt.usm_ndarray) + # 0 broadcast to [0, 0] per array API + assert y.shape == ii.shape + assert _all_equal( + (x[ii[i], 0, ii[i]] for i in range(ii.shape[0])), + (y[i] for i in range(ii.shape[0])), + ) def test_advanced_slice6(): @@ -395,6 +402,44 @@ def test_advanced_slice13(): assert (dpt.asnumpy(y) == dpt.asnumpy(expected)).all() +def test_advanced_slice14(): + q = get_queue_or_skip() + ii = dpt.asarray([1, 2], sycl_queue=q) + x = dpt.reshape(dpt.arange(3**5, dtype="i4", sycl_queue=q), (3,) * 5) + y = x[ii, 0, ii, 1, :] + assert isinstance(y, dpt.usm_ndarray) + # integers broadcast to ii.shape per array API + assert y.shape == ii.shape + x.shape[-1:] + assert _all_equal( + ( + x[ii[i], 0, ii[i], 1, k] + for i in range(ii.shape[0]) + for k in range(x.shape[-1]) + ), + (y[i, k] for i in range(ii.shape[0]) for k in range(x.shape[-1])), + )
+ + +def test_advanced_slice15(): + q = get_queue_or_skip() + ii = dpt.asarray([1, 2], sycl_queue=q) + x = dpt.reshape(dpt.arange(3**5, dtype="i4", sycl_queue=q), (3,) * 5) + # : cannot appear between two integral arrays + with pytest.raises(IndexError): + x[ii, 0, ii, :, ii] + + +def test_advanced_slice16(): + q = get_queue_or_skip() + ii = dpt.asarray(1, sycl_queue=q) + i0 = dpt.asarray(False, sycl_queue=q) + i1 = dpt.asarray(True, sycl_queue=q) + x = dpt.reshape(dpt.arange(3**5, dtype="i4", sycl_queue=q), (3,) * 5) + y = x[ii, i0, ii, i1, :] + # TODO: add a shape check here when discrepancy with NumPy is investigated + assert isinstance(y, dpt.usm_ndarray) + + def test_boolean_indexing_validation(): get_queue_or_skip() x = dpt.zeros(10, dtype="i4") @@ -1956,3 +2001,17 @@ def test_take_out_errors(): out_bad_q = dpt.empty(ind.shape, dtype=x.dtype, sycl_queue=q2) with pytest.raises(dpctl.utils.ExecutionPlacementError): dpt.take(x, ind, out=out_bad_q) + + +def test_getitem_impl_fn_invalid_inp(): + get_queue_or_skip() + + x = dpt.ones((10, 10), dtype="i4") + + bad_ind_type = (dpt.ones((), dtype="i4"), 2.0) + with pytest.raises(TypeError): + _take_multi_index(x, bad_ind_type, 0, 0) + + no_array_inds = (2, 3) + with pytest.raises(TypeError): + _take_multi_index(x, no_array_inds, 0, 0)
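
Illustration (not part of the patch): with this change, a Python integer mixed with integer index arrays follows the array API and is broadcast against the arrays instead of raising IndexError, as exercised by test_advanced_slice5 and test_advanced_slice14 above. A minimal sketch of the new behavior, assuming a default SYCL device is available (queues are omitted here):

    # sketch: integer scalars now participate in advanced indexing alongside arrays
    import dpctl.tensor as dpt

    x = dpt.reshape(dpt.arange(27, dtype="i4"), (3, 3, 3))
    ii = dpt.asarray([1, 2])

    # previously raised IndexError; the scalar 0 now broadcasts with `ii`
    y = x[ii, 0, ii]
    print(y.shape)          # (2,), i.e. ii.shape
    print(dpt.asnumpy(y))   # values of x[1, 0, 1] and x[2, 0, 2]

Note that a slice may still not appear between two index arrays (see test_advanced_slice15): an expression such as x[ii, 0, ii, :, ii] continues to raise IndexError.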
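For reference, the refactored _prepare_indices_arrays normalizes a mixed index tuple in three steps: scalars become 0-d arrays allocated on the execution queue with the coerced usm type, all index arrays are promoted to a common integral dtype, and the results are broadcast to one shape. The following is a rough equivalent using only public dpctl.tensor calls; it is illustrative only, since the real helper also threads through the queue and usm type derived from the inputs:

    import dpctl.tensor as dpt

    inds = (dpt.asarray([1, 2], dtype="i8"), 0)   # array index + Python int

    # scalars -> 0-d arrays
    inds = tuple(
        i if isinstance(i, dpt.usm_ndarray) else dpt.asarray(i) for i in inds
    )
    # promote to a common integral dtype, then broadcast to a common shape
    common_dt = dpt.result_type(*inds)
    inds = tuple(dpt.astype(i, common_dt) for i in inds)
    inds = dpt.broadcast_arrays(*inds)
    print([i.shape for i in inds])   # [(2,), (2,)]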