Skip to content

Commit bc2a2d3

Browse files
Move statistic utils to common place as header only
1 parent ec65f73 commit bc2a2d3

18 files changed

+79
-79
lines changed

dpnp/backend/extensions/statistics/common.hpp dpnp/backend/extensions/common/ext/common.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
namespace type_utils = dpctl::tensor::type_utils;
3737

38-
namespace statistics::common
38+
namespace ext::common
3939
{
4040

4141
template <typename N, typename D>
@@ -185,4 +185,6 @@ sycl::nd_range<1>
185185
// headers of dpctl.
186186
pybind11::dtype dtype_from_typenum(int dst_typenum);
187187

188-
} // namespace statistics::common
188+
} // namespace ext::common
189+
190+
#include "ext/details/common_internal.hpp"

dpnp/backend/extensions/statistics/common.cpp dpnp/backend/extensions/common/ext/details/common_internal.hpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26-
#include "common.hpp"
26+
#include "ext/common.hpp"
2727
#include "utils/type_dispatch.hpp"
2828
#include <pybind11/pybind11.h>
2929

3030
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3131

32-
namespace statistics::common
32+
namespace ext::common
3333
{
34-
size_t get_max_local_size(const sycl::device &device)
34+
inline size_t get_max_local_size(const sycl::device &device)
3535
{
3636
constexpr const int default_max_cpu_local_size = 256;
3737
constexpr const int default_max_gpu_local_size = 0;
@@ -40,7 +40,7 @@ size_t get_max_local_size(const sycl::device &device)
4040
default_max_gpu_local_size);
4141
}
4242

43-
size_t get_max_local_size(const sycl::device &device,
43+
inline size_t get_max_local_size(const sycl::device &device,
4444
int cpu_local_size_limit,
4545
int gpu_local_size_limit)
4646
{
@@ -56,30 +56,30 @@ size_t get_max_local_size(const sycl::device &device,
5656
return max_work_group_size;
5757
}
5858

59-
sycl::nd_range<1>
59+
inline sycl::nd_range<1>
6060
make_ndrange(size_t global_size, size_t local_range, size_t work_per_item)
6161
{
6262
return make_ndrange(sycl::range<1>(global_size),
6363
sycl::range<1>(local_range),
6464
sycl::range<1>(work_per_item));
6565
}
6666

67-
size_t get_local_mem_size_in_bytes(const sycl::device &device)
67+
inline size_t get_local_mem_size_in_bytes(const sycl::device &device)
6868
{
6969
// Reserving 1kb for runtime needs
7070
constexpr const size_t reserve = 1024;
7171

7272
return get_local_mem_size_in_bytes(device, reserve);
7373
}
7474

75-
size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve)
75+
inline size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve)
7676
{
7777
size_t local_mem_size =
7878
device.get_info<sycl::info::device::local_mem_size>();
7979
return local_mem_size - reserve;
8080
}
8181

82-
pybind11::dtype dtype_from_typenum(int dst_typenum)
82+
inline pybind11::dtype dtype_from_typenum(int dst_typenum)
8383
{
8484
dpctl_td_ns::typenum_t dst_typenum_t =
8585
static_cast<dpctl_td_ns::typenum_t>(dst_typenum);
@@ -117,4 +117,4 @@ pybind11::dtype dtype_from_typenum(int dst_typenum)
117117
}
118118
}
119119

120-
} // namespace statistics::common
120+
} // namespace ext::common

dpnp/backend/extensions/statistics/validation_utils.cpp dpnp/backend/extensions/common/ext/details/validation_utils_internal.hpp

+17-24
Original file line numberDiff line numberDiff line change
@@ -23,39 +23,32 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26-
#include "validation_utils.hpp"
26+
#include "ext/validation_utils.hpp"
2727
#include "utils/memory_overlap.hpp"
2828

29-
using statistics::validation::array_names;
30-
using statistics::validation::array_ptr;
31-
32-
namespace
29+
namespace ext::validation
3330
{
34-
35-
sycl::queue get_queue(const std::vector<array_ptr> &inputs,
36-
const std::vector<array_ptr> &outputs)
31+
inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
32+
const std::vector<array_ptr> &outputs)
3733
{
3834
auto it = std::find_if(inputs.cbegin(), inputs.cend(),
39-
[](const array_ptr &arr) { return arr != nullptr; });
35+
[](const array_ptr &arr) { return arr != nullptr; });
4036

4137
if (it != inputs.cend()) {
4238
return (*it)->get_queue();
4339
}
4440

4541
it = std::find_if(outputs.cbegin(), outputs.cend(),
46-
[](const array_ptr &arr) { return arr != nullptr; });
42+
[](const array_ptr &arr) { return arr != nullptr; });
4743

4844
if (it != outputs.cend()) {
4945
return (*it)->get_queue();
5046
}
5147

5248
throw py::value_error("No input or output arrays found");
5349
}
54-
} // namespace
5550

56-
namespace statistics::validation
57-
{
58-
std::string name_of(const array_ptr &arr, const array_names &names)
51+
inline std::string name_of(const array_ptr &arr, const array_names &names)
5952
{
6053
auto name_it = names.find(arr);
6154
assert(name_it != names.end());
@@ -66,7 +59,7 @@ std::string name_of(const array_ptr &arr, const array_names &names)
6659
return "'unknown'";
6760
}
6861

69-
void check_writable(const std::vector<array_ptr> &arrays,
62+
inline void check_writable(const std::vector<array_ptr> &arrays,
7063
const array_names &names)
7164
{
7265
for (const auto &arr : arrays) {
@@ -77,7 +70,7 @@ void check_writable(const std::vector<array_ptr> &arrays,
7770
}
7871
}
7972

80-
void check_c_contig(const std::vector<array_ptr> &arrays,
73+
inline void check_c_contig(const std::vector<array_ptr> &arrays,
8174
const array_names &names)
8275
{
8376
for (const auto &arr : arrays) {
@@ -88,7 +81,7 @@ void check_c_contig(const std::vector<array_ptr> &arrays,
8881
}
8982
}
9083

91-
void check_queue(const std::vector<array_ptr> &arrays,
84+
inline void check_queue(const std::vector<array_ptr> &arrays,
9285
const array_names &names,
9386
const sycl::queue &exec_q)
9487
{
@@ -104,7 +97,7 @@ void check_queue(const std::vector<array_ptr> &arrays,
10497
}
10598
}
10699

107-
void check_no_overlap(const array_ptr &input,
100+
inline void check_no_overlap(const array_ptr &input,
108101
const array_ptr &output,
109102
const array_names &names)
110103
{
@@ -121,7 +114,7 @@ void check_no_overlap(const array_ptr &input,
121114
}
122115
}
123116

124-
void check_no_overlap(const std::vector<array_ptr> &inputs,
117+
inline void check_no_overlap(const std::vector<array_ptr> &inputs,
125118
const std::vector<array_ptr> &outputs,
126119
const array_names &names)
127120
{
@@ -132,7 +125,7 @@ void check_no_overlap(const std::vector<array_ptr> &inputs,
132125
}
133126
}
134127

135-
void check_num_dims(const array_ptr &arr,
128+
inline void check_num_dims(const array_ptr &arr,
136129
const size_t ndim,
137130
const array_names &names)
138131
{
@@ -144,7 +137,7 @@ void check_num_dims(const array_ptr &arr,
144137
}
145138
}
146139

147-
void check_max_dims(const array_ptr &arr,
140+
inline void check_max_dims(const array_ptr &arr,
148141
const size_t max_ndim,
149142
const array_names &names)
150143
{
@@ -157,7 +150,7 @@ void check_max_dims(const array_ptr &arr,
157150
}
158151
}
159152

160-
void check_size_at_least(const array_ptr &arr,
153+
inline void check_size_at_least(const array_ptr &arr,
161154
const size_t size,
162155
const array_names &names)
163156
{
@@ -170,7 +163,7 @@ void check_size_at_least(const array_ptr &arr,
170163
}
171164
}
172165

173-
void common_checks(const std::vector<array_ptr> &inputs,
166+
inline void common_checks(const std::vector<array_ptr> &inputs,
174167
const std::vector<array_ptr> &outputs,
175168
const array_names &names)
176169
{
@@ -187,4 +180,4 @@ void common_checks(const std::vector<array_ptr> &inputs,
187180
check_no_overlap(inputs, outputs, names);
188181
}
189182

190-
} // namespace statistics::validation
183+
} // namespace ext::validation

dpnp/backend/extensions/statistics/dispatch_table.hpp dpnp/backend/extensions/common/ext/dispatch_table.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@
3434
#include <pybind11/stl.h>
3535
#include <sycl/sycl.hpp>
3636

37-
#include "common.hpp"
37+
#include "ext/common.hpp"
3838

3939
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4040
namespace py = pybind11;
4141

42-
namespace statistics::common
42+
namespace ext::common
4343
{
4444
template <typename T, typename Rest>
4545
struct one_of
@@ -383,4 +383,4 @@ class DispatchTable2
383383
Table2<FnT> table;
384384
};
385385

386-
} // namespace statistics::common
386+
} // namespace ext::common

dpnp/backend/extensions/statistics/validation_utils.hpp dpnp/backend/extensions/common/ext/validation_utils.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
#include "dpctl4pybind11.hpp"
3333

34-
namespace statistics::validation
34+
namespace ext::validation
3535
{
3636
using array_ptr = const dpctl::tensor::usm_ndarray *;
3737
using array_names = std::unordered_map<array_ptr, std::string>;
@@ -67,4 +67,6 @@ void check_size_at_least(const array_ptr &arr,
6767
void common_checks(const std::vector<array_ptr> &inputs,
6868
const std::vector<array_ptr> &outputs,
6969
const array_names &names);
70-
} // namespace statistics::validation
70+
} // namespace ext::validation
71+
72+
#include "ext/details/validation_utils_internal.hpp"

dpnp/backend/extensions/statistics/CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,12 @@
2727
set(python_module_name _statistics_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
30-
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3130
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
3231
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
3332
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3433
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
3534
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
3635
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
37-
${CMAKE_CURRENT_SOURCE_DIR}/validation_utils.cpp
3836
)
3937

4038
pybind11_add_module(${python_module_name} MODULE ${_module_src})
@@ -66,6 +64,7 @@ set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDEN
6664

6765
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
6866
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
67+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
6968

7069
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
7170
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

dpnp/backend/extensions/statistics/bincount.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
using dpctl::tensor::usm_ndarray;
3535

3636
using namespace statistics::histogram;
37-
using namespace statistics::common;
37+
using namespace ext::common;
3838

3939
namespace
4040
{

dpnp/backend/extensions/statistics/bincount.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
#include <pybind11/pybind11.h>
2929
#include <sycl/sycl.hpp>
3030

31-
#include "dispatch_table.hpp"
3231
#include "dpctl4pybind11.hpp"
32+
#include "ext/dispatch_table.hpp"
3333

3434
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3535

@@ -46,7 +46,7 @@ struct Bincount
4646
const size_t,
4747
const std::vector<sycl::event> &);
4848

49-
common::DispatchTable2<FnT> dispatch_table;
49+
ext::common::DispatchTable2<FnT> dispatch_table;
5050

5151
Bincount();
5252

dpnp/backend/extensions/statistics/histogram.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4343
using dpctl::tensor::usm_ndarray;
4444

4545
using namespace statistics::histogram;
46-
using namespace statistics::common;
46+
using namespace ext::common;
4747

4848
namespace
4949
{

dpnp/backend/extensions/statistics/histogram.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
#include <pybind11/pybind11.h>
2929
#include <sycl/sycl.hpp>
3030

31-
#include "dispatch_table.hpp"
3231
#include "dpctl4pybind11.hpp"
32+
#include "ext/dispatch_table.hpp"
3333

3434
namespace statistics::histogram
3535
{
@@ -44,7 +44,7 @@ struct Histogram
4444
const size_t,
4545
const std::vector<sycl::event> &);
4646

47-
common::DispatchTable2<FnT> dispatch_table;
47+
ext::common::DispatchTable2<FnT> dispatch_table;
4848

4949
Histogram();
5050

dpnp/backend/extensions/statistics/histogram_common.cpp

+12-11
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,25 @@
3535

3636
#include "histogram_common.hpp"
3737

38-
#include "validation_utils.hpp"
38+
#include "ext/validation_utils.hpp"
3939

4040
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4141
using dpctl::tensor::usm_ndarray;
4242
using dpctl_td_ns::typenum_t;
4343

44-
namespace statistics
45-
{
46-
using common::CeilDiv;
44+
using ext::common::CeilDiv;
45+
46+
using ext::validation::array_names;
47+
using ext::validation::array_ptr;
4748

48-
using validation::array_names;
49-
using validation::array_ptr;
49+
using ext::validation::check_max_dims;
50+
using ext::validation::check_num_dims;
51+
using ext::validation::check_size_at_least;
52+
using ext::validation::common_checks;
53+
using ext::validation::name_of;
5054

51-
using validation::check_max_dims;
52-
using validation::check_num_dims;
53-
using validation::check_size_at_least;
54-
using validation::common_checks;
55-
using validation::name_of;
55+
namespace statistics
56+
{
5657

5758
namespace histogram
5859
{

dpnp/backend/extensions/statistics/histogram_common.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
#include <sycl/sycl.hpp>
2929

30-
#include "common.hpp"
30+
#include "ext/common.hpp"
3131

3232
namespace dpctl::tensor
3333
{
@@ -36,11 +36,12 @@ class usm_ndarray;
3636

3737
using dpctl::tensor::usm_ndarray;
3838

39+
using ext::common::AtomicOp;
40+
using ext::common::IsNan;
41+
using ext::common::Less;
42+
3943
namespace statistics
4044
{
41-
using common::AtomicOp;
42-
using common::IsNan;
43-
using common::Less;
4445

4546
namespace histogram
4647
{

0 commit comments

Comments
 (0)