38
38
39
39
#include " histogram_common.hpp"
40
40
41
+ #include " validation_utils.hpp"
42
+
41
43
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
42
44
using dpctl::tensor::usm_ndarray;
43
45
using dpctl_td_ns::typenum_t ;
@@ -46,6 +48,15 @@ namespace statistics
46
48
{
47
49
using common::CeilDiv;
48
50
51
+ using validation::array_names;
52
+ using validation::array_ptr;
53
+
54
+ using validation::check_max_dims;
55
+ using validation::check_num_dims;
56
+ using validation::check_size_at_least;
57
+ using validation::common_checks;
58
+ using validation::name_of;
59
+
49
60
namespace histogram
50
61
{
51
62
@@ -55,11 +66,9 @@ void validate(const usm_ndarray &sample,
55
66
const usm_ndarray &histogram)
56
67
{
57
68
auto exec_q = sample.get_queue ();
58
- using array_ptr = const usm_ndarray *;
59
69
60
70
std::vector<array_ptr> arrays{&sample, &histogram};
61
- std::unordered_map<array_ptr, std::string> names = {
62
- {arrays[0 ], " sample" }, {arrays[1 ], " histogram" }};
71
+ array_names names = {{arrays[0 ], " sample" }, {arrays[1 ], " histogram" }};
63
72
64
73
array_ptr bins_ptr = nullptr ;
65
74
@@ -77,117 +86,48 @@ void validate(const usm_ndarray &sample,
77
86
names.insert ({weights_ptr, " weights" });
78
87
}
79
88
80
- auto get_name = [&](const array_ptr &arr) {
81
- auto name_it = names.find (arr);
82
- assert (name_it != names.end ());
83
-
84
- return " '" + name_it->second + " '" ;
85
- };
86
-
87
- dpctl::tensor::validation::CheckWritable::throw_if_not_writable (histogram);
88
-
89
- auto unequal_queue =
90
- std::find_if (arrays.cbegin (), arrays.cend (), [&](const array_ptr &arr) {
91
- return arr->get_queue () != exec_q;
92
- });
93
-
94
- if (unequal_queue != arrays.cend ()) {
95
- throw py::value_error (
96
- get_name (*unequal_queue) +
97
- " parameter has incompatible queue with parameter " +
98
- get_name (&sample));
99
- }
100
-
101
- auto non_contig_array =
102
- std::find_if (arrays.cbegin (), arrays.cend (), [&](const array_ptr &arr) {
103
- return !arr->is_c_contiguous ();
104
- });
89
+ common_checks ({&sample, bins.has_value () ? &bins.value () : nullptr ,
90
+ weights.has_value () ? &weights.value () : nullptr },
91
+ {&histogram}, names);
105
92
106
- if (non_contig_array != arrays.cend ()) {
107
- throw py::value_error (get_name (*non_contig_array) +
108
- " parameter is not c-contiguos" );
109
- }
93
+ check_size_at_least (bins_ptr, 2 , names);
110
94
111
- auto check_overlaping = [&](const array_ptr &first,
112
- const array_ptr &second) {
113
- if (first == nullptr || second == nullptr ) {
114
- return ;
115
- }
116
-
117
- const auto &overlap = dpctl::tensor::overlap::MemoryOverlap ();
118
-
119
- if (overlap (*first, *second)) {
120
- throw py::value_error (get_name (first) +
121
- " has overlapping memory segments with " +
122
- get_name (second));
123
- }
124
- };
125
-
126
- check_overlaping (&sample, &histogram);
127
- check_overlaping (bins_ptr, &histogram);
128
- check_overlaping (weights_ptr, &histogram);
129
-
130
- if (bins_ptr && bins_ptr->get_size () < 2 ) {
131
- throw py::value_error (get_name (bins_ptr) +
132
- " parameter must have at least 2 elements" );
133
- }
134
-
135
- if (histogram.get_size () < 1 ) {
136
- throw py::value_error (get_name (&histogram) +
137
- " parameter must have at least 1 element" );
138
- }
139
-
140
- if (histogram.get_ndim () != 1 ) {
141
- throw py::value_error (get_name (&histogram) +
142
- " parameter must be 1d. Actual " +
143
- std::to_string (histogram.get_ndim ()) + " d" );
144
- }
95
+ check_size_at_least (&histogram, 1 , names);
96
+ check_num_dims (&histogram, 1 , names);
145
97
146
98
if (weights_ptr) {
147
- if (weights_ptr->get_ndim () != 1 ) {
148
- throw py::value_error (
149
- get_name (weights_ptr) + " parameter must be 1d. Actual " +
150
- std::to_string (weights_ptr->get_ndim ()) + " d" );
151
- }
99
+ check_num_dims (weights_ptr, 1 , names);
152
100
153
101
auto sample_size = sample.get_size ();
154
102
auto weights_size = weights_ptr->get_size ();
155
103
if (sample.get_size () != weights_ptr->get_size ()) {
156
- throw py::value_error (
157
- get_name (&sample) + " size (" + std::to_string (sample_size) +
158
- " ) and " + get_name (weights_ptr) + " size (" +
159
- std::to_string (weights_size) + " )" + " must match" );
104
+ throw py::value_error (name_of (&sample, names) + " size (" +
105
+ std::to_string (sample_size) + " ) and " +
106
+ name_of (weights_ptr, names) + " size (" +
107
+ std::to_string (weights_size) + " )" +
108
+ " must match" );
160
109
}
161
110
}
162
111
163
- if (sample.get_ndim () > 2 ) {
164
- throw py::value_error (
165
- get_name (&sample) +
166
- " parameter must have no more than 2 dimensions. Actual " +
167
- std::to_string (sample.get_ndim ()) + " d" );
168
- }
112
+ check_max_dims (&sample, 2 , names);
169
113
170
114
if (sample.get_ndim () == 1 ) {
171
- if (bins_ptr != nullptr && bins_ptr->get_ndim () != 1 ) {
172
- throw py::value_error (get_name (&sample) + " parameter is 1d, but " +
173
- get_name (bins_ptr) + " is " +
174
- std::to_string (bins_ptr->get_ndim ()) + " d" );
175
- }
115
+ check_num_dims (bins_ptr, 1 , names);
176
116
}
177
117
else if (sample.get_ndim () == 2 ) {
178
118
auto sample_count = sample.get_shape (0 );
179
119
auto expected_dims = sample.get_shape (1 );
180
120
181
121
if (bins_ptr != nullptr && bins_ptr->get_ndim () != expected_dims) {
182
- throw py::value_error (get_name (&sample) + " parameter has shape { " +
183
- std::to_string (sample_count) + " x " +
184
- std::to_string (expected_dims ) + " } " +
185
- " , so " + get_name (bins_ptr) +
186
- " parameter expected to be " +
187
- std::to_string (expected_dims) +
188
- " d. "
189
- " Actual " +
190
- std::to_string (bins->get_ndim ()) + " d" );
122
+ throw py::value_error (
123
+ name_of (&sample, names) + " parameter has shape { " +
124
+ std::to_string (sample_count ) + " x " +
125
+ std::to_string (expected_dims) + " } " + " , so " +
126
+ name_of (bins_ptr, names) + " parameter expected to be " +
127
+ std::to_string (expected_dims) +
128
+ " d. "
129
+ " Actual " +
130
+ std::to_string (bins->get_ndim ()) + " d" );
191
131
}
192
132
}
193
133
@@ -199,17 +139,17 @@ void validate(const usm_ndarray &sample,
199
139
200
140
if (histogram.get_size () != expected_hist_size) {
201
141
throw py::value_error (
202
- get_name (&histogram) + " and " + get_name (bins_ptr) +
203
- " shape mismatch. " + get_name (&histogram) +
204
- " expected to have size = " +
142
+ name_of (&histogram, names ) + " and " +
143
+ name_of (bins_ptr, names) + " shape mismatch. " +
144
+ name_of (&histogram, names) + " expected to have size = " +
205
145
std::to_string (expected_hist_size) + " . Actual " +
206
146
std::to_string (histogram.get_size ()));
207
147
}
208
148
}
209
149
210
150
int64_t max_hist_size = std::numeric_limits<uint32_t >::max () - 1 ;
211
151
if (histogram.get_size () > max_hist_size) {
212
- throw py::value_error (get_name (&histogram) +
152
+ throw py::value_error (name_of (&histogram, names ) +
213
153
" parameter size expected to be less than " +
214
154
std::to_string (max_hist_size) + " . Actual " +
215
155
std::to_string (histogram.get_size ()));
@@ -225,7 +165,7 @@ void validate(const usm_ndarray &sample,
225
165
if (!_64bit_atomics) {
226
166
auto device_name = device.get_info <sycl::info::device::name>();
227
167
throw py::value_error (
228
- get_name (&histogram) +
168
+ name_of (&histogram, names ) +
229
169
" parameter has 64-bit type, but 64-bit atomics " +
230
170
" are not supported for " + device_name);
231
171
}
0 commit comments