Skip to content

Commit 89ec213

Browse files
committed
Rename Python bindings for integer indexing
Aligns with most non-constructors
1 parent bc7a739 commit 89ec213

File tree

3 files changed

+32
-32
lines changed

3 files changed

+32
-32
lines changed

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

+14-14
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,13 @@ std::vector<dpctl::tensor::usm_ndarray> parse_py_ind(const sycl::queue &q,
244244
}
245245

246246
std::pair<sycl::event, sycl::event>
247-
usm_ndarray_take(const dpctl::tensor::usm_ndarray &src,
248-
const py::object &py_ind,
249-
const dpctl::tensor::usm_ndarray &dst,
250-
int axis_start,
251-
std::uint8_t mode,
252-
sycl::queue &exec_q,
253-
const std::vector<sycl::event> &depends)
247+
py_take(const dpctl::tensor::usm_ndarray &src,
248+
const py::object &py_ind,
249+
const dpctl::tensor::usm_ndarray &dst,
250+
int axis_start,
251+
std::uint8_t mode,
252+
sycl::queue &exec_q,
253+
const std::vector<sycl::event> &depends)
254254
{
255255
std::vector<dpctl::tensor::usm_ndarray> ind = parse_py_ind(exec_q, py_ind);
256256

@@ -515,13 +515,13 @@ usm_ndarray_take(const dpctl::tensor::usm_ndarray &src,
515515
}
516516

517517
std::pair<sycl::event, sycl::event>
518-
usm_ndarray_put(const dpctl::tensor::usm_ndarray &dst,
519-
const py::object &py_ind,
520-
const dpctl::tensor::usm_ndarray &val,
521-
int axis_start,
522-
std::uint8_t mode,
523-
sycl::queue &exec_q,
524-
const std::vector<sycl::event> &depends)
518+
py_put(const dpctl::tensor::usm_ndarray &dst,
519+
const py::object &py_ind,
520+
const dpctl::tensor::usm_ndarray &val,
521+
int axis_start,
522+
std::uint8_t mode,
523+
sycl::queue &exec_q,
524+
const std::vector<sycl::event> &depends)
525525
{
526526
std::vector<dpctl::tensor::usm_ndarray> ind = parse_py_ind(exec_q, py_ind);
527527
int k = ind.size();

dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp

+14-14
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,22 @@ namespace py_internal
3939
{
4040

4141
extern std::pair<sycl::event, sycl::event>
42-
usm_ndarray_take(const dpctl::tensor::usm_ndarray &,
43-
const py::object &,
44-
const dpctl::tensor::usm_ndarray &,
45-
int,
46-
std::uint8_t,
47-
sycl::queue &,
48-
const std::vector<sycl::event> & = {});
42+
py_take(const dpctl::tensor::usm_ndarray &,
43+
const py::object &,
44+
const dpctl::tensor::usm_ndarray &,
45+
int,
46+
std::uint8_t,
47+
sycl::queue &,
48+
const std::vector<sycl::event> & = {});
4949

5050
extern std::pair<sycl::event, sycl::event>
51-
usm_ndarray_put(const dpctl::tensor::usm_ndarray &,
52-
const py::object &,
53-
const dpctl::tensor::usm_ndarray &,
54-
int,
55-
std::uint8_t,
56-
sycl::queue &,
57-
const std::vector<sycl::event> & = {});
51+
py_put(const dpctl::tensor::usm_ndarray &,
52+
const py::object &,
53+
const dpctl::tensor::usm_ndarray &,
54+
int,
55+
std::uint8_t,
56+
sycl::queue &,
57+
const std::vector<sycl::event> & = {});
5858

5959
extern void init_advanced_indexing_dispatch_tables(void);
6060

dpctl/tensor/libtensor/source/tensor_ctors.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ using dpctl::tensor::py_internal::usm_ndarray_full;
101101
using dpctl::tensor::py_internal::usm_ndarray_zeros;
102102

103103
/* ============== Advanced Indexing ============= */
104-
using dpctl::tensor::py_internal::usm_ndarray_put;
105-
using dpctl::tensor::py_internal::usm_ndarray_take;
104+
using dpctl::tensor::py_internal::py_put;
105+
using dpctl::tensor::py_internal::py_take;
106106

107107
using dpctl::tensor::py_internal::py_extract;
108108
using dpctl::tensor::py_internal::py_mask_positions;
@@ -324,7 +324,7 @@ PYBIND11_MODULE(_tensor_impl, m)
324324
py::arg("fill_value"), py::arg("dst"), py::arg("sycl_queue"),
325325
py::arg("depends") = py::list());
326326

327-
m.def("_take", &usm_ndarray_take,
327+
m.def("_take", &py_take,
328328
"Takes elements at usm_ndarray indices `ind` and axes starting "
329329
"at axis `axis_start` from array `src` and copies them "
330330
"into usm_ndarray `dst` synchronously."
@@ -333,7 +333,7 @@ PYBIND11_MODULE(_tensor_impl, m)
333333
py::arg("mode"), py::arg("sycl_queue"),
334334
py::arg("depends") = py::list());
335335

336-
m.def("_put", &usm_ndarray_put,
336+
m.def("_put", &py_put,
337337
"Puts elements at usm_ndarray indices `ind` and axes starting "
338338
"at axis `axis_start` into array `dst` from "
339339
"usm_ndarray `val` synchronously."

0 commit comments

Comments
 (0)