From 89ec2136bd23b332e3bf0bab6496a604a9aea6fa Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 13 Jan 2025 17:59:10 -0800 Subject: [PATCH 1/2] Rename Python bindings for integer indexing Aligns with most non-constructors --- .../source/integer_advanced_indexing.cpp | 28 +++++++++---------- .../source/integer_advanced_indexing.hpp | 28 +++++++++---------- .../tensor/libtensor/source/tensor_ctors.cpp | 8 +++--- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp index 5eb54bbe70..47fce829e3 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -244,13 +244,13 @@ std::vector parse_py_ind(const sycl::queue &q, } std::pair -usm_ndarray_take(const dpctl::tensor::usm_ndarray &src, - const py::object &py_ind, - const dpctl::tensor::usm_ndarray &dst, - int axis_start, - std::uint8_t mode, - sycl::queue &exec_q, - const std::vector &depends) +py_take(const dpctl::tensor::usm_ndarray &src, + const py::object &py_ind, + const dpctl::tensor::usm_ndarray &dst, + int axis_start, + std::uint8_t mode, + sycl::queue &exec_q, + const std::vector &depends) { std::vector ind = parse_py_ind(exec_q, py_ind); @@ -515,13 +515,13 @@ usm_ndarray_take(const dpctl::tensor::usm_ndarray &src, } std::pair -usm_ndarray_put(const dpctl::tensor::usm_ndarray &dst, - const py::object &py_ind, - const dpctl::tensor::usm_ndarray &val, - int axis_start, - std::uint8_t mode, - sycl::queue &exec_q, - const std::vector &depends) +py_put(const dpctl::tensor::usm_ndarray &dst, + const py::object &py_ind, + const dpctl::tensor::usm_ndarray &val, + int axis_start, + std::uint8_t mode, + sycl::queue &exec_q, + const std::vector &depends) { std::vector ind = parse_py_ind(exec_q, py_ind); int k = ind.size(); diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp index 10555b3dad..8627a7f9a2 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp @@ -39,22 +39,22 @@ namespace py_internal { extern std::pair -usm_ndarray_take(const dpctl::tensor::usm_ndarray &, - const py::object &, - const dpctl::tensor::usm_ndarray &, - int, - std::uint8_t, - sycl::queue &, - const std::vector & = {}); +py_take(const dpctl::tensor::usm_ndarray &, + const py::object &, + const dpctl::tensor::usm_ndarray &, + int, + std::uint8_t, + sycl::queue &, + const std::vector & = {}); extern std::pair -usm_ndarray_put(const dpctl::tensor::usm_ndarray &, - const py::object &, - const dpctl::tensor::usm_ndarray &, - int, - std::uint8_t, - sycl::queue &, - const std::vector & = {}); +py_put(const dpctl::tensor::usm_ndarray &, + const py::object &, + const dpctl::tensor::usm_ndarray &, + int, + std::uint8_t, + sycl::queue &, + const std::vector & = {}); extern void init_advanced_indexing_dispatch_tables(void); diff --git a/dpctl/tensor/libtensor/source/tensor_ctors.cpp b/dpctl/tensor/libtensor/source/tensor_ctors.cpp index 2dccc4e359..c7894c0f36 100644 --- a/dpctl/tensor/libtensor/source/tensor_ctors.cpp +++ b/dpctl/tensor/libtensor/source/tensor_ctors.cpp @@ -101,8 +101,8 @@ using dpctl::tensor::py_internal::usm_ndarray_full; using dpctl::tensor::py_internal::usm_ndarray_zeros; /* ============== Advanced Indexing ============= */ -using dpctl::tensor::py_internal::usm_ndarray_put; -using dpctl::tensor::py_internal::usm_ndarray_take; +using dpctl::tensor::py_internal::py_put; +using dpctl::tensor::py_internal::py_take; using dpctl::tensor::py_internal::py_extract; using dpctl::tensor::py_internal::py_mask_positions; @@ -324,7 +324,7 @@ PYBIND11_MODULE(_tensor_impl, m) py::arg("fill_value"), py::arg("dst"), py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_take", &usm_ndarray_take, + m.def("_take", &py_take, "Takes elements at usm_ndarray indices `ind` and axes starting " "at axis `axis_start` from array `src` and copies them " "into usm_ndarray `dst` synchronously." @@ -333,7 +333,7 @@ PYBIND11_MODULE(_tensor_impl, m) py::arg("mode"), py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_put", &usm_ndarray_put, + m.def("_put", &py_put, "Puts elements at usm_ndarray indices `ind` and axes starting " "at axis `axis_start` into array `dst` from " "usm_ndarray `val` synchronously." From d409117d8d2b91992a5b797c3000bdf0c101d6a1 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 21 Feb 2025 17:22:40 -0800 Subject: [PATCH 2/2] Change integer indexing mode dispatching --- .../source/integer_advanced_indexing.cpp | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp index 47fce829e3..6cf6067938 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -46,10 +46,6 @@ #include "integer_advanced_indexing.hpp" -#define INDEXING_MODES 2 -#define WRAP_MODE 0 -#define CLIP_MODE 1 - namespace dpctl { namespace tensor @@ -62,11 +58,15 @@ namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::kernels::indexing::put_fn_ptr_t; using dpctl::tensor::kernels::indexing::take_fn_ptr_t; -static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types] - [td_ns::num_types]; +static take_fn_ptr_t take_wrap_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static take_fn_ptr_t take_clip_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static put_fn_ptr_t put_wrap_dispatch_table[td_ns::num_types][td_ns::num_types]; -static put_fn_ptr_t put_dispatch_table[INDEXING_MODES][td_ns::num_types] - [td_ns::num_types]; +static put_fn_ptr_t put_clip_dispatch_table[td_ns::num_types][td_ns::num_types]; namespace py = pybind11; @@ -486,7 +486,8 @@ py_take(const dpctl::tensor::usm_ndarray &src, std::end(pack_deps)); all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends)); - auto fn = take_dispatch_table[mode][src_type_id][ind_type_id]; + auto fn = mode ? take_clip_dispatch_table[src_type_id][ind_type_id] + : take_wrap_dispatch_table[src_type_id][ind_type_id]; if (fn == nullptr) { sycl::event::wait(host_task_events); @@ -755,7 +756,8 @@ py_put(const dpctl::tensor::usm_ndarray &dst, std::end(pack_deps)); all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends)); - auto fn = put_dispatch_table[mode][dst_type_id][ind_type_id]; + auto fn = mode ? put_clip_dispatch_table[dst_type_id][ind_type_id] + : put_wrap_dispatch_table[dst_type_id][ind_type_id]; if (fn == nullptr) { sycl::event::wait(host_task_events); @@ -790,20 +792,20 @@ void init_advanced_indexing_dispatch_tables(void) using dpctl::tensor::kernels::indexing::TakeClipFactory; DispatchTableBuilder dtb_takeclip; - dtb_takeclip.populate_dispatch_table(take_dispatch_table[CLIP_MODE]); + dtb_takeclip.populate_dispatch_table(take_clip_dispatch_table); using dpctl::tensor::kernels::indexing::TakeWrapFactory; DispatchTableBuilder dtb_takewrap; - dtb_takewrap.populate_dispatch_table(take_dispatch_table[WRAP_MODE]); + dtb_takewrap.populate_dispatch_table(take_wrap_dispatch_table); using dpctl::tensor::kernels::indexing::PutClipFactory; DispatchTableBuilder dtb_putclip; - dtb_putclip.populate_dispatch_table(put_dispatch_table[CLIP_MODE]); + dtb_putclip.populate_dispatch_table(put_clip_dispatch_table); using dpctl::tensor::kernels::indexing::PutWrapFactory; DispatchTableBuilder dtb_putwrap; - dtb_putwrap.populate_dispatch_table(put_dispatch_table[WRAP_MODE]); + dtb_putwrap.populate_dispatch_table(put_wrap_dispatch_table); } } // namespace py_internal