2626#include < pybind11/pybind11.h>
2727
2828// dpctl tensor headers
29+ #include " utils/output_validation.hpp"
2930#include " utils/type_dispatch.hpp"
3031#include " utils/type_utils.hpp"
3132#include " kernels/alignment.hpp"
@@ -51,7 +52,22 @@ using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
5152using dpctl_krn_ns::is_aligned;
5253using dpctl_krn_ns::required_alignment;
5354
54- constexpr int no_of_methods = 2 ; // number of methods of gaussian distribution
55+ constexpr auto no_of_methods = 2 ; // number of methods of gaussian distribution
56+
57+ constexpr auto seq_of_vec_sizes = std::integer_sequence<std::uint8_t , 2 , 4 , 8 , 16 >{};
58+ constexpr auto vec_sizes_len = seq_of_vec_sizes.size();
59+ constexpr auto no_of_engines = engine::no_of_engines * vec_sizes_len;
60+
61+ template <typename VecSizeT, VecSizeT ...Ints, auto ...Indices>
62+ inline auto find_vec_size_impl (const VecSizeT vec_size, std::index_sequence<Indices...>) {
63+ return std::min ({ ((Ints == vec_size) ? Indices : sizeof ...(Indices))... });
64+ }
65+
66+ template <typename VecSizeT, VecSizeT ...Ints>
67+ int find_vec_size (const VecSizeT vec_size, std::integer_sequence<VecSizeT, Ints...>) {
68+ auto res = find_vec_size_impl<VecSizeT, Ints...>(vec_size, std::make_index_sequence<sizeof ...(Ints)>{});
69+ return (res == sizeof ...(Ints)) ? -1 : res;
70+ }
5571
5672template <typename DataT, typename Method>
5773struct DistributorBuilder
@@ -83,7 +99,7 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(engine::EngineBase *engine,
8399 char *,
84100 const std::vector<sycl::event> &);
85101
86- static gaussian_impl_fn_ptr_t gaussian_dispatch_table[engine:: no_of_engines][dpctl_td_ns::num_types][no_of_methods];
102+ static gaussian_impl_fn_ptr_t gaussian_dispatch_table[no_of_engines][dpctl_td_ns::num_types][no_of_methods];
87103
88104template <typename EngineT, typename DataT, typename Method, unsigned int items_per_wi>
89105class gaussian_kernel ;
@@ -117,7 +133,7 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
117133
118134 using EngineBuilderT = engine::builder::Builder<EngineT>;
119135 EngineBuilderT eng_builder (engine);
120- eng_builder.print (); // TODO: remove
136+ // eng_builder.print(); // TODO: remove
121137
122138 using DistributorBuilderT = DistributorBuilder<DataT, Method>;
123139 DistributorBuilderT dist_builder (mean, stddev);
@@ -154,6 +170,7 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
154170
155171std::pair<sycl::event, sycl::event> gaussian (engine::EngineBase *engine,
156172 const std::uint8_t method_id,
173+ const std::uint8_t vec_size,
157174 const double mean,
158175 const double stddev,
159176 const std::uint64_t n,
@@ -176,15 +193,10 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
176193 }
177194
178195 // ensure that output is ample enough to accommodate all elements
179- auto res_offsets = res.get_minmax_offsets ();
180- // destination must be ample enough to accommodate all elements
181- {
182- size_t range =
183- static_cast <size_t >(res_offsets.second - res_offsets.first );
184- if (range + 1 < res_nelems) {
185- throw py::value_error (
186- " Destination array can not accommodate all the elements of source array." );
187- }
196+ dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (res, res_nelems);
197+
198+ if (!dpctl::utils::queues_are_compatible (exec_q, {res})) {
199+ throw py::value_error (" Execution queue is not compatible with the allocation queue" );
188200 }
189201
190202 bool is_res_c_contig = res.is_c_contiguous ();
@@ -201,6 +213,12 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
201213 throw std::runtime_error (" Unknown method=" + std::to_string (method_id) + " for gaussian distribution." );
202214 }
203215
216+ int vec_size_id = find_vec_size (vec_size, seq_of_vec_sizes);
217+ if (vec_size_id < 0 ) {
218+ throw std::runtime_error (" Vector size=" + std::to_string (vec_size) + " is out of supported range" );
219+ }
220+ enginge_id = enginge_id * vec_sizes_len + vec_size_id;
221+
204222 auto array_types = dpctl_td_ns::usm_ndarray_types ();
205223 int res_type_id = array_types.typenum_to_lookup_id (res.get_typenum ());
206224
@@ -232,7 +250,7 @@ struct GaussianContigFactory
232250
233251void init_gaussian_dispatch_3d_table (void )
234252{
235- dispatch::Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, engine:: no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
236- contig.populate (gaussian_dispatch_table);
253+ dispatch::Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
254+ contig.populate (gaussian_dispatch_table, seq_of_vec_sizes );
237255}
238256} // dpnp::backend::ext::rng::device
0 commit comments