diff --git a/dpnp/backend/extensions/statistics/common.hpp b/dpnp/backend/extensions/common/ext/common.hpp similarity index 98% rename from dpnp/backend/extensions/statistics/common.hpp rename to dpnp/backend/extensions/common/ext/common.hpp index 39f1535112fa..9a45e21a4e7a 100644 --- a/dpnp/backend/extensions/statistics/common.hpp +++ b/dpnp/backend/extensions/common/ext/common.hpp @@ -35,7 +35,7 @@ namespace type_utils = dpctl::tensor::type_utils; -namespace statistics::common +namespace ext::common { template @@ -185,4 +185,6 @@ sycl::nd_range<1> // headers of dpctl. pybind11::dtype dtype_from_typenum(int dst_typenum); -} // namespace statistics::common +} // namespace ext::common + +#include "ext/details/common_internal.hpp" diff --git a/dpnp/backend/extensions/statistics/common.cpp b/dpnp/backend/extensions/common/ext/details/common_internal.hpp similarity index 87% rename from dpnp/backend/extensions/statistics/common.cpp rename to dpnp/backend/extensions/common/ext/details/common_internal.hpp index 379817435b00..c0c3f444c2e3 100644 --- a/dpnp/backend/extensions/statistics/common.cpp +++ b/dpnp/backend/extensions/common/ext/details/common_internal.hpp @@ -23,15 +23,15 @@ // THE POSSIBILITY OF SUCH DAMAGE. //***************************************************************************** -#include "common.hpp" +#include "ext/common.hpp" #include "utils/type_dispatch.hpp" #include namespace dpctl_td_ns = dpctl::tensor::type_dispatch; -namespace statistics::common +namespace ext::common { -size_t get_max_local_size(const sycl::device &device) +inline size_t get_max_local_size(const sycl::device &device) { constexpr const int default_max_cpu_local_size = 256; constexpr const int default_max_gpu_local_size = 0; @@ -40,9 +40,9 @@ size_t get_max_local_size(const sycl::device &device) default_max_gpu_local_size); } -size_t get_max_local_size(const sycl::device &device, - int cpu_local_size_limit, - int gpu_local_size_limit) +inline size_t get_max_local_size(const sycl::device &device, + int cpu_local_size_limit, + int gpu_local_size_limit) { int max_work_group_size = device.get_info(); @@ -56,7 +56,7 @@ size_t get_max_local_size(const sycl::device &device, return max_work_group_size; } -sycl::nd_range<1> +inline sycl::nd_range<1> make_ndrange(size_t global_size, size_t local_range, size_t work_per_item) { return make_ndrange(sycl::range<1>(global_size), @@ -64,7 +64,7 @@ sycl::nd_range<1> sycl::range<1>(work_per_item)); } -size_t get_local_mem_size_in_bytes(const sycl::device &device) +inline size_t get_local_mem_size_in_bytes(const sycl::device &device) { // Reserving 1kb for runtime needs constexpr const size_t reserve = 1024; @@ -72,14 +72,15 @@ size_t get_local_mem_size_in_bytes(const sycl::device &device) return get_local_mem_size_in_bytes(device, reserve); } -size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve) +inline size_t get_local_mem_size_in_bytes(const sycl::device &device, + size_t reserve) { size_t local_mem_size = device.get_info(); return local_mem_size - reserve; } -pybind11::dtype dtype_from_typenum(int dst_typenum) +inline pybind11::dtype dtype_from_typenum(int dst_typenum) { dpctl_td_ns::typenum_t dst_typenum_t = static_cast(dst_typenum); @@ -117,4 +118,4 @@ pybind11::dtype dtype_from_typenum(int dst_typenum) } } -} // namespace statistics::common +} // namespace ext::common diff --git a/dpnp/backend/extensions/statistics/validation_utils.cpp b/dpnp/backend/extensions/common/ext/details/validation_utils_internal.hpp similarity index 74% rename from dpnp/backend/extensions/statistics/validation_utils.cpp rename to dpnp/backend/extensions/common/ext/details/validation_utils_internal.hpp index 882e288e0271..d5a65d3b9961 100644 --- a/dpnp/backend/extensions/statistics/validation_utils.cpp +++ b/dpnp/backend/extensions/common/ext/details/validation_utils_internal.hpp @@ -23,17 +23,13 @@ // THE POSSIBILITY OF SUCH DAMAGE. //***************************************************************************** -#include "validation_utils.hpp" +#include "ext/validation_utils.hpp" #include "utils/memory_overlap.hpp" -using statistics::validation::array_names; -using statistics::validation::array_ptr; - -namespace +namespace ext::validation { - -sycl::queue get_queue(const std::vector &inputs, - const std::vector &outputs) +inline sycl::queue get_queue(const std::vector &inputs, + const std::vector &outputs) { auto it = std::find_if(inputs.cbegin(), inputs.cend(), [](const array_ptr &arr) { return arr != nullptr; }); @@ -51,11 +47,8 @@ sycl::queue get_queue(const std::vector &inputs, throw py::value_error("No input or output arrays found"); } -} // namespace -namespace statistics::validation -{ -std::string name_of(const array_ptr &arr, const array_names &names) +inline std::string name_of(const array_ptr &arr, const array_names &names) { auto name_it = names.find(arr); assert(name_it != names.end()); @@ -66,8 +59,8 @@ std::string name_of(const array_ptr &arr, const array_names &names) return "'unknown'"; } -void check_writable(const std::vector &arrays, - const array_names &names) +inline void check_writable(const std::vector &arrays, + const array_names &names) { for (const auto &arr : arrays) { if (arr != nullptr && !arr->is_writable()) { @@ -77,8 +70,8 @@ void check_writable(const std::vector &arrays, } } -void check_c_contig(const std::vector &arrays, - const array_names &names) +inline void check_c_contig(const std::vector &arrays, + const array_names &names) { for (const auto &arr : arrays) { if (arr != nullptr && !arr->is_c_contiguous()) { @@ -88,9 +81,9 @@ void check_c_contig(const std::vector &arrays, } } -void check_queue(const std::vector &arrays, - const array_names &names, - const sycl::queue &exec_q) +inline void check_queue(const std::vector &arrays, + const array_names &names, + const sycl::queue &exec_q) { auto unequal_queue = std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) { @@ -104,9 +97,9 @@ void check_queue(const std::vector &arrays, } } -void check_no_overlap(const array_ptr &input, - const array_ptr &output, - const array_names &names) +inline void check_no_overlap(const array_ptr &input, + const array_ptr &output, + const array_names &names) { if (input == nullptr || output == nullptr) { return; @@ -121,9 +114,9 @@ void check_no_overlap(const array_ptr &input, } } -void check_no_overlap(const std::vector &inputs, - const std::vector &outputs, - const array_names &names) +inline void check_no_overlap(const std::vector &inputs, + const std::vector &outputs, + const array_names &names) { for (const auto &input : inputs) { for (const auto &output : outputs) { @@ -132,9 +125,9 @@ void check_no_overlap(const std::vector &inputs, } } -void check_num_dims(const array_ptr &arr, - const size_t ndim, - const array_names &names) +inline void check_num_dims(const array_ptr &arr, + const size_t ndim, + const array_names &names) { size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0; if (arr != nullptr && arr_n_dim != ndim) { @@ -144,9 +137,9 @@ void check_num_dims(const array_ptr &arr, } } -void check_max_dims(const array_ptr &arr, - const size_t max_ndim, - const array_names &names) +inline void check_max_dims(const array_ptr &arr, + const size_t max_ndim, + const array_names &names) { size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0; if (arr != nullptr && arr_n_dim > max_ndim) { @@ -157,9 +150,9 @@ void check_max_dims(const array_ptr &arr, } } -void check_size_at_least(const array_ptr &arr, - const size_t size, - const array_names &names) +inline void check_size_at_least(const array_ptr &arr, + const size_t size, + const array_names &names) { size_t arr_size = arr != nullptr ? arr->get_size() : 0; if (arr != nullptr && arr_size < size) { @@ -170,9 +163,9 @@ void check_size_at_least(const array_ptr &arr, } } -void common_checks(const std::vector &inputs, - const std::vector &outputs, - const array_names &names) +inline void common_checks(const std::vector &inputs, + const std::vector &outputs, + const array_names &names) { check_writable(outputs, names); @@ -187,4 +180,4 @@ void common_checks(const std::vector &inputs, check_no_overlap(inputs, outputs, names); } -} // namespace statistics::validation +} // namespace ext::validation diff --git a/dpnp/backend/extensions/statistics/dispatch_table.hpp b/dpnp/backend/extensions/common/ext/dispatch_table.hpp similarity index 99% rename from dpnp/backend/extensions/statistics/dispatch_table.hpp rename to dpnp/backend/extensions/common/ext/dispatch_table.hpp index 1e58a5b917f1..64cf994d52f0 100644 --- a/dpnp/backend/extensions/statistics/dispatch_table.hpp +++ b/dpnp/backend/extensions/common/ext/dispatch_table.hpp @@ -34,12 +34,12 @@ #include #include -#include "common.hpp" +#include "ext/common.hpp" namespace dpctl_td_ns = dpctl::tensor::type_dispatch; namespace py = pybind11; -namespace statistics::common +namespace ext::common { template struct one_of @@ -383,4 +383,4 @@ class DispatchTable2 Table2 table; }; -} // namespace statistics::common +} // namespace ext::common diff --git a/dpnp/backend/extensions/statistics/validation_utils.hpp b/dpnp/backend/extensions/common/ext/validation_utils.hpp similarity index 96% rename from dpnp/backend/extensions/statistics/validation_utils.hpp rename to dpnp/backend/extensions/common/ext/validation_utils.hpp index f8e1487c1d09..53b71b07e427 100644 --- a/dpnp/backend/extensions/statistics/validation_utils.hpp +++ b/dpnp/backend/extensions/common/ext/validation_utils.hpp @@ -31,7 +31,7 @@ #include "dpctl4pybind11.hpp" -namespace statistics::validation +namespace ext::validation { using array_ptr = const dpctl::tensor::usm_ndarray *; using array_names = std::unordered_map; @@ -67,4 +67,6 @@ void check_size_at_least(const array_ptr &arr, void common_checks(const std::vector &inputs, const std::vector &outputs, const array_names &names); -} // namespace statistics::validation +} // namespace ext::validation + +#include "ext/details/validation_utils_internal.hpp" diff --git a/dpnp/backend/extensions/statistics/CMakeLists.txt b/dpnp/backend/extensions/statistics/CMakeLists.txt index b11714849ffd..2a5467bff382 100644 --- a/dpnp/backend/extensions/statistics/CMakeLists.txt +++ b/dpnp/backend/extensions/statistics/CMakeLists.txt @@ -27,14 +27,12 @@ set(python_module_name _statistics_impl) set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validation_utils.cpp ) pybind11_add_module(${python_module_name} MODULE ${_module_src}) @@ -66,6 +64,7 @@ set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDEN target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include) target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src) +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common) target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR}) target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR}) diff --git a/dpnp/backend/extensions/statistics/bincount.cpp b/dpnp/backend/extensions/statistics/bincount.cpp index 79dafc6ec5d9..81912291f4c5 100644 --- a/dpnp/backend/extensions/statistics/bincount.cpp +++ b/dpnp/backend/extensions/statistics/bincount.cpp @@ -34,7 +34,7 @@ using dpctl::tensor::usm_ndarray; using namespace statistics::histogram; -using namespace statistics::common; +using namespace ext::common; namespace { diff --git a/dpnp/backend/extensions/statistics/bincount.hpp b/dpnp/backend/extensions/statistics/bincount.hpp index ec65c5399e7e..c2365f35fcfe 100644 --- a/dpnp/backend/extensions/statistics/bincount.hpp +++ b/dpnp/backend/extensions/statistics/bincount.hpp @@ -28,8 +28,8 @@ #include #include -#include "dispatch_table.hpp" #include "dpctl4pybind11.hpp" +#include "ext/dispatch_table.hpp" namespace dpctl_td_ns = dpctl::tensor::type_dispatch; @@ -46,7 +46,7 @@ struct Bincount const size_t, const std::vector &); - common::DispatchTable2 dispatch_table; + ext::common::DispatchTable2 dispatch_table; Bincount(); diff --git a/dpnp/backend/extensions/statistics/histogram.cpp b/dpnp/backend/extensions/statistics/histogram.cpp index 1ee0d818a620..5e05a44858f2 100644 --- a/dpnp/backend/extensions/statistics/histogram.cpp +++ b/dpnp/backend/extensions/statistics/histogram.cpp @@ -43,7 +43,7 @@ namespace dpctl_td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::usm_ndarray; using namespace statistics::histogram; -using namespace statistics::common; +using namespace ext::common; namespace { diff --git a/dpnp/backend/extensions/statistics/histogram.hpp b/dpnp/backend/extensions/statistics/histogram.hpp index f64744fbcbf2..4daae8d4214c 100644 --- a/dpnp/backend/extensions/statistics/histogram.hpp +++ b/dpnp/backend/extensions/statistics/histogram.hpp @@ -28,8 +28,8 @@ #include #include -#include "dispatch_table.hpp" #include "dpctl4pybind11.hpp" +#include "ext/dispatch_table.hpp" namespace statistics::histogram { @@ -44,7 +44,7 @@ struct Histogram const size_t, const std::vector &); - common::DispatchTable2 dispatch_table; + ext::common::DispatchTable2 dispatch_table; Histogram(); diff --git a/dpnp/backend/extensions/statistics/histogram_common.cpp b/dpnp/backend/extensions/statistics/histogram_common.cpp index a6b35f091880..1840dad6358c 100644 --- a/dpnp/backend/extensions/statistics/histogram_common.cpp +++ b/dpnp/backend/extensions/statistics/histogram_common.cpp @@ -35,26 +35,24 @@ #include "histogram_common.hpp" -#include "validation_utils.hpp" +#include "ext/validation_utils.hpp" namespace dpctl_td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::usm_ndarray; using dpctl_td_ns::typenum_t; -namespace statistics -{ -using common::CeilDiv; +using ext::common::CeilDiv; -using validation::array_names; -using validation::array_ptr; +using ext::validation::array_names; +using ext::validation::array_ptr; -using validation::check_max_dims; -using validation::check_num_dims; -using validation::check_size_at_least; -using validation::common_checks; -using validation::name_of; +using ext::validation::check_max_dims; +using ext::validation::check_num_dims; +using ext::validation::check_size_at_least; +using ext::validation::common_checks; +using ext::validation::name_of; -namespace histogram +namespace statistics::histogram { void validate(const usm_ndarray &sample, @@ -199,5 +197,4 @@ uint32_t get_local_hist_copies_count(uint32_t loc_mem_size_in_items, return local_hist_count; } -} // namespace histogram -} // namespace statistics +} // namespace statistics::histogram diff --git a/dpnp/backend/extensions/statistics/histogram_common.hpp b/dpnp/backend/extensions/statistics/histogram_common.hpp index 58127d708a2a..e826e4a9a4a1 100644 --- a/dpnp/backend/extensions/statistics/histogram_common.hpp +++ b/dpnp/backend/extensions/statistics/histogram_common.hpp @@ -27,7 +27,7 @@ #include -#include "common.hpp" +#include "ext/common.hpp" namespace dpctl::tensor { @@ -36,13 +36,11 @@ class usm_ndarray; using dpctl::tensor::usm_ndarray; -namespace statistics -{ -using common::AtomicOp; -using common::IsNan; -using common::Less; +using ext::common::AtomicOp; +using ext::common::IsNan; +using ext::common::Less; -namespace histogram +namespace statistics::histogram { template @@ -369,5 +367,4 @@ uint32_t get_local_hist_copies_count(uint32_t loc_mem_size_in_items, uint32_t local_size, uint32_t hist_size_in_items); -} // namespace histogram -} // namespace statistics +} // namespace statistics::histogram diff --git a/dpnp/backend/extensions/statistics/histogramdd.cpp b/dpnp/backend/extensions/statistics/histogramdd.cpp index f3c1af675143..9567eb5c1927 100644 --- a/dpnp/backend/extensions/statistics/histogramdd.cpp +++ b/dpnp/backend/extensions/statistics/histogramdd.cpp @@ -34,7 +34,7 @@ using dpctl::tensor::usm_ndarray; using namespace statistics::histogram; -using namespace statistics::common; +using namespace ext::common; namespace { diff --git a/dpnp/backend/extensions/statistics/histogramdd.hpp b/dpnp/backend/extensions/statistics/histogramdd.hpp index 2a8ba7798d55..87c0e906ff84 100644 --- a/dpnp/backend/extensions/statistics/histogramdd.hpp +++ b/dpnp/backend/extensions/statistics/histogramdd.hpp @@ -28,12 +28,10 @@ #include #include -#include "dispatch_table.hpp" #include "dpctl4pybind11.hpp" +#include "ext/dispatch_table.hpp" -namespace statistics -{ -namespace histogram +namespace statistics::histogram { struct Histogramdd { @@ -49,7 +47,7 @@ struct Histogramdd const size_t, const std::vector &); - common::DispatchTable2 dispatch_table; + ext::common::DispatchTable2 dispatch_table; Histogramdd(); @@ -63,5 +61,4 @@ struct Histogramdd }; void populate_histogramdd(py::module_ m); -} // namespace histogram -} // namespace statistics +} // namespace statistics::histogram diff --git a/dpnp/backend/extensions/statistics/sliding_dot_product1d.cpp b/dpnp/backend/extensions/statistics/sliding_dot_product1d.cpp index 02d2ead3f64f..8a647c02366d 100644 --- a/dpnp/backend/extensions/statistics/sliding_dot_product1d.cpp +++ b/dpnp/backend/extensions/statistics/sliding_dot_product1d.cpp @@ -34,7 +34,7 @@ #include "dpctl4pybind11.hpp" #include "utils/type_dispatch.hpp" -#include "common.hpp" +#include "ext/common.hpp" #include "sliding_dot_product1d.hpp" #include "sliding_window1d.hpp" @@ -44,7 +44,7 @@ namespace dpctl_td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::usm_ndarray; using namespace statistics::sliding_window1d; -using namespace statistics::common; +using namespace ext::common; namespace { diff --git a/dpnp/backend/extensions/statistics/sliding_dot_product1d.hpp b/dpnp/backend/extensions/statistics/sliding_dot_product1d.hpp index a4acf26cb3fe..25fd5ed2e88f 100644 --- a/dpnp/backend/extensions/statistics/sliding_dot_product1d.hpp +++ b/dpnp/backend/extensions/statistics/sliding_dot_product1d.hpp @@ -25,7 +25,7 @@ #pragma once -#include "dispatch_table.hpp" +#include "ext/dispatch_table.hpp" #include #include @@ -43,7 +43,7 @@ struct SlidingDotProduct1d const size_t, const std::vector &); - common::DispatchTable dispatch_table; + ext::common::DispatchTable dispatch_table; SlidingDotProduct1d(); diff --git a/dpnp/backend/extensions/statistics/sliding_window1d.cpp b/dpnp/backend/extensions/statistics/sliding_window1d.cpp index 59cf837e6eaa..6d6c0b171d5c 100644 --- a/dpnp/backend/extensions/statistics/sliding_window1d.cpp +++ b/dpnp/backend/extensions/statistics/sliding_window1d.cpp @@ -30,22 +30,20 @@ #include "utils/type_dispatch.hpp" #include +#include "ext/validation_utils.hpp" #include "sliding_window1d.hpp" -#include "validation_utils.hpp" namespace dpctl_td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::usm_ndarray; using dpctl_td_ns::typenum_t; -namespace statistics -{ -using validation::array_names; -using validation::array_ptr; -using validation::check_num_dims; -using validation::common_checks; -using validation::name_of; +using ext::validation::array_names; +using ext::validation::array_ptr; +using ext::validation::check_num_dims; +using ext::validation::common_checks; +using ext::validation::name_of; -namespace sliding_window1d +namespace statistics::sliding_window1d { void validate(const usm_ndarray &a, @@ -89,5 +87,4 @@ void validate(const usm_ndarray &a, } } -} // namespace sliding_window1d -} // namespace statistics +} // namespace statistics::sliding_window1d diff --git a/dpnp/backend/extensions/statistics/sliding_window1d.hpp b/dpnp/backend/extensions/statistics/sliding_window1d.hpp index 278fdaaa83c0..11352fc3a91b 100644 --- a/dpnp/backend/extensions/statistics/sliding_window1d.hpp +++ b/dpnp/backend/extensions/statistics/sliding_window1d.hpp @@ -31,16 +31,14 @@ #include -#include "common.hpp" +#include "ext/common.hpp" using dpctl::tensor::usm_ndarray; -namespace statistics -{ -using common::Align; -using common::CeilDiv; +using ext::common::Align; +using ext::common::CeilDiv; -namespace sliding_window1d +namespace statistics::sliding_window1d { template @@ -668,5 +666,4 @@ void validate(const usm_ndarray &a, const usm_ndarray &out, const size_t l_pad, const size_t r_pad); -} // namespace sliding_window1d -} // namespace statistics +} // namespace statistics::sliding_window1d