26
26
#include < pybind11/pybind11.h>
27
27
28
28
// dpctl tensor headers
29
+ #include " utils/output_validation.hpp"
29
30
#include " utils/type_dispatch.hpp"
30
31
#include " utils/type_utils.hpp"
31
32
#include " kernels/alignment.hpp"
@@ -51,7 +52,22 @@ using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
51
52
using dpctl_krn_ns::is_aligned;
52
53
using dpctl_krn_ns::required_alignment;
53
54
54
- constexpr int no_of_methods = 2 ; // number of methods of gaussian distribution
55
+ constexpr auto no_of_methods = 2 ; // number of methods of gaussian distribution
56
+
57
// Compile-time list of the vector sizes accepted by gaussian(); it is searched
// by find_vec_size() and handed to the dispatch table builder's populate().
+ constexpr auto seq_of_vec_sizes = std::integer_sequence<std::uint8_t , 2 , 4 , 8 , 16 >{};
58
// Number of entries in seq_of_vec_sizes.
+ constexpr auto vec_sizes_len = seq_of_vec_sizes.size();
59
// Extent of the first dimension of gaussian_dispatch_table: one slot per
// (engine, vector size) pair. gaussian() indexes it as
// engine id * vec_sizes_len + vector size index.
+ constexpr auto no_of_engines = engine::no_of_engines * vec_sizes_len;
60
+
61
/// Searches the compile-time values `Ints...` for `vec_size`.
///
/// `Indices...` is deduced from the index-sequence argument and is paired
/// element-wise with `Ints...`, so both packs must have the same length.
///
/// @param vec_size value to look for.
/// @return the smallest index whose value equals `vec_size`, or
///         `sizeof...(Indices)` when nothing matches (this includes the case
///         of an empty value list).
template <typename VecSizeT, VecSizeT... Ints, auto... Indices>
constexpr auto find_vec_size_impl([[maybe_unused]] const VecSizeT vec_size,
                                  std::index_sequence<Indices...>)
{
    constexpr auto not_found = sizeof...(Indices);
    auto pos = not_found;

    // Comma fold over every (value, index) pair: remember the lowest matching
    // index. Unlike std::min over an initializer list, this is well-formed for
    // an empty pack and needs nothing from <algorithm>.
    ((pos = (Ints == vec_size && Indices < pos) ? Indices : pos), ...);

    return pos;
}

/// Maps a run-time vector size onto its position inside a compile-time
/// `std::integer_sequence` of supported sizes.
///
/// @param vec_size value to look for.
/// @return zero-based position of `vec_size` in `Ints...`, or -1 when the
///         value is not part of the sequence.
template <typename VecSizeT, VecSizeT... Ints>
constexpr int find_vec_size(const VecSizeT vec_size,
                            std::integer_sequence<VecSizeT, Ints...>)
{
    const auto pos = find_vec_size_impl<VecSizeT, Ints...>(
        vec_size, std::make_index_sequence<sizeof...(Ints)>{});

    // Both branches are `int`: avoids converting -1 to an unsigned type and
    // narrowing it back, which the mixed int/size_t conditional used to do.
    return (pos == sizeof...(Ints)) ? -1 : static_cast<int>(pos);
}
55
71
56
72
template <typename DataT, typename Method>
57
73
struct DistributorBuilder
@@ -83,7 +99,7 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(engine::EngineBase *engine,
83
99
char *,
84
100
const std::vector<sycl::event> &);
85
101
86
- static gaussian_impl_fn_ptr_t gaussian_dispatch_table[engine:: no_of_engines][dpctl_td_ns::num_types][no_of_methods];
102
+ static gaussian_impl_fn_ptr_t gaussian_dispatch_table[no_of_engines][dpctl_td_ns::num_types][no_of_methods];
87
103
88
104
template <typename EngineT, typename DataT, typename Method, unsigned int items_per_wi>
89
105
class gaussian_kernel ;
@@ -117,7 +133,7 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
117
133
118
134
using EngineBuilderT = engine::builder::Builder<EngineT>;
119
135
EngineBuilderT eng_builder (engine);
120
- eng_builder.print (); // TODO: remove
136
+ // eng_builder.print(); // TODO: remove
121
137
122
138
using DistributorBuilderT = DistributorBuilder<DataT, Method>;
123
139
DistributorBuilderT dist_builder (mean, stddev);
@@ -154,6 +170,7 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
154
170
155
171
std::pair<sycl::event, sycl::event> gaussian (engine::EngineBase *engine,
156
172
const std::uint8_t method_id,
173
+ const std::uint8_t vec_size,
157
174
const double mean,
158
175
const double stddev,
159
176
const std::uint64_t n,
@@ -176,15 +193,10 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
176
193
}
177
194
178
195
// ensure that output is ample enough to accommodate all elements
179
- auto res_offsets = res.get_minmax_offsets ();
180
- // destination must be ample enough to accommodate all elements
181
- {
182
- size_t range =
183
- static_cast <size_t >(res_offsets.second - res_offsets.first );
184
- if (range + 1 < res_nelems) {
185
- throw py::value_error (
186
- " Destination array can not accommodate all the elements of source array." );
187
- }
196
+ dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (res, res_nelems);
197
+
198
+ if (!dpctl::utils::queues_are_compatible (exec_q, {res})) {
199
+ throw py::value_error (" Execution queue is not compatible with the allocation queue" );
188
200
}
189
201
190
202
bool is_res_c_contig = res.is_c_contiguous ();
@@ -201,6 +213,12 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
201
213
throw std::runtime_error (" Unknown method=" + std::to_string (method_id) + " for gaussian distribution." );
202
214
}
203
215
216
+ int vec_size_id = find_vec_size (vec_size, seq_of_vec_sizes);
217
+ if (vec_size_id < 0 ) {
218
+ throw std::runtime_error (" Vector size=" + std::to_string (vec_size) + " is out of supported range" );
219
+ }
220
+ enginge_id = enginge_id * vec_sizes_len + vec_size_id;
221
+
204
222
auto array_types = dpctl_td_ns::usm_ndarray_types ();
205
223
int res_type_id = array_types.typenum_to_lookup_id (res.get_typenum ());
206
224
@@ -232,7 +250,7 @@ struct GaussianContigFactory
232
250
233
251
// Fills gaussian_dispatch_table using a dispatch::Dispatch3DTableBuilder that
// is parameterised with GaussianContigFactory and the three table extents
// (engines, dpctl_td_ns::num_types, no_of_methods).
// The lines marked '-' are the previous form (extent engine::no_of_engines,
// populate() with the table only); the lines marked '+' use the widened
// no_of_engines extent and additionally pass seq_of_vec_sizes to populate().
void init_gaussian_dispatch_3d_table (void )
234
252
{
235
- dispatch::Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, engine:: no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
236
- contig.populate (gaussian_dispatch_table);
253
+ dispatch::Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
254
+ contig.populate (gaussian_dispatch_table, seq_of_vec_sizes );
237
255
}
238
256
} // dpnp::backend::ext::rng::device
0 commit comments