23
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
24
// *****************************************************************************
25
25
26
- #include " validation_utils.hpp"
26
+ #include " ext/ validation_utils.hpp"
27
27
#include " utils/memory_overlap.hpp"
28
28
29
- using statistics::validation::array_names;
30
- using statistics::validation::array_ptr;
31
-
32
- namespace
29
+ namespace ext ::validation
33
30
{
34
-
35
- sycl::queue get_queue (const std::vector<array_ptr> &inputs,
36
- const std::vector<array_ptr> &outputs)
31
+ inline sycl::queue get_queue (const std::vector<array_ptr> &inputs,
32
+ const std::vector<array_ptr> &outputs)
37
33
{
38
34
auto it = std::find_if (inputs.cbegin (), inputs.cend (),
39
35
[](const array_ptr &arr) { return arr != nullptr ; });
@@ -51,11 +47,8 @@ sycl::queue get_queue(const std::vector<array_ptr> &inputs,
51
47
52
48
throw py::value_error (" No input or output arrays found" );
53
49
}
54
- } // namespace
55
50
56
- namespace statistics ::validation
57
- {
58
- std::string name_of (const array_ptr &arr, const array_names &names)
51
+ inline std::string name_of (const array_ptr &arr, const array_names &names)
59
52
{
60
53
auto name_it = names.find (arr);
61
54
assert (name_it != names.end ());
@@ -66,8 +59,8 @@ std::string name_of(const array_ptr &arr, const array_names &names)
66
59
return " 'unknown'" ;
67
60
}
68
61
69
- void check_writable (const std::vector<array_ptr> &arrays,
70
- const array_names &names)
62
+ inline void check_writable (const std::vector<array_ptr> &arrays,
63
+ const array_names &names)
71
64
{
72
65
for (const auto &arr : arrays) {
73
66
if (arr != nullptr && !arr->is_writable ()) {
@@ -77,8 +70,8 @@ void check_writable(const std::vector<array_ptr> &arrays,
77
70
}
78
71
}
79
72
80
- void check_c_contig (const std::vector<array_ptr> &arrays,
81
- const array_names &names)
73
+ inline void check_c_contig (const std::vector<array_ptr> &arrays,
74
+ const array_names &names)
82
75
{
83
76
for (const auto &arr : arrays) {
84
77
if (arr != nullptr && !arr->is_c_contiguous ()) {
@@ -88,9 +81,9 @@ void check_c_contig(const std::vector<array_ptr> &arrays,
88
81
}
89
82
}
90
83
91
- void check_queue (const std::vector<array_ptr> &arrays,
92
- const array_names &names,
93
- const sycl::queue &exec_q)
84
+ inline void check_queue (const std::vector<array_ptr> &arrays,
85
+ const array_names &names,
86
+ const sycl::queue &exec_q)
94
87
{
95
88
auto unequal_queue =
96
89
std::find_if (arrays.cbegin (), arrays.cend (), [&](const array_ptr &arr) {
@@ -104,9 +97,9 @@ void check_queue(const std::vector<array_ptr> &arrays,
104
97
}
105
98
}
106
99
107
- void check_no_overlap (const array_ptr &input,
108
- const array_ptr &output,
109
- const array_names &names)
100
+ inline void check_no_overlap (const array_ptr &input,
101
+ const array_ptr &output,
102
+ const array_names &names)
110
103
{
111
104
if (input == nullptr || output == nullptr ) {
112
105
return ;
@@ -121,9 +114,9 @@ void check_no_overlap(const array_ptr &input,
121
114
}
122
115
}
123
116
124
- void check_no_overlap (const std::vector<array_ptr> &inputs,
125
- const std::vector<array_ptr> &outputs,
126
- const array_names &names)
117
+ inline void check_no_overlap (const std::vector<array_ptr> &inputs,
118
+ const std::vector<array_ptr> &outputs,
119
+ const array_names &names)
127
120
{
128
121
for (const auto &input : inputs) {
129
122
for (const auto &output : outputs) {
@@ -132,9 +125,9 @@ void check_no_overlap(const std::vector<array_ptr> &inputs,
132
125
}
133
126
}
134
127
135
- void check_num_dims (const array_ptr &arr,
136
- const size_t ndim,
137
- const array_names &names)
128
+ inline void check_num_dims (const array_ptr &arr,
129
+ const size_t ndim,
130
+ const array_names &names)
138
131
{
139
132
size_t arr_n_dim = arr != nullptr ? arr->get_ndim () : 0 ;
140
133
if (arr != nullptr && arr_n_dim != ndim) {
@@ -144,9 +137,9 @@ void check_num_dims(const array_ptr &arr,
144
137
}
145
138
}
146
139
147
- void check_max_dims (const array_ptr &arr,
148
- const size_t max_ndim,
149
- const array_names &names)
140
+ inline void check_max_dims (const array_ptr &arr,
141
+ const size_t max_ndim,
142
+ const array_names &names)
150
143
{
151
144
size_t arr_n_dim = arr != nullptr ? arr->get_ndim () : 0 ;
152
145
if (arr != nullptr && arr_n_dim > max_ndim) {
@@ -157,9 +150,9 @@ void check_max_dims(const array_ptr &arr,
157
150
}
158
151
}
159
152
160
- void check_size_at_least (const array_ptr &arr,
161
- const size_t size,
162
- const array_names &names)
153
+ inline void check_size_at_least (const array_ptr &arr,
154
+ const size_t size,
155
+ const array_names &names)
163
156
{
164
157
size_t arr_size = arr != nullptr ? arr->get_size () : 0 ;
165
158
if (arr != nullptr && arr_size < size) {
@@ -170,9 +163,9 @@ void check_size_at_least(const array_ptr &arr,
170
163
}
171
164
}
172
165
173
- void common_checks (const std::vector<array_ptr> &inputs,
174
- const std::vector<array_ptr> &outputs,
175
- const array_names &names)
166
+ inline void common_checks (const std::vector<array_ptr> &inputs,
167
+ const std::vector<array_ptr> &outputs,
168
+ const array_names &names)
176
169
{
177
170
check_writable (outputs, names);
178
171
@@ -187,4 +180,4 @@ void common_checks(const std::vector<array_ptr> &inputs,
187
180
check_no_overlap (inputs, outputs, names);
188
181
}
189
182
190
- } // namespace statistics ::validation
183
+ } // namespace ext ::validation
0 commit comments