Skip to content

Commit 0d81755

Browse files
committed
Exposed vector size of an engine
1 parent f16f5cc commit 0d81755

File tree

5 files changed

+42
-20
lines changed

5 files changed

+42
-20
lines changed

dpnp/backend/extensions/rng/device/dispatch/table_builder.hpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,13 @@ class Dispatch3DTableBuilder
7777
Dispatch3DTableBuilder() = default;
7878
~Dispatch3DTableBuilder() = default;
7979

80-
void populate(funcPtrT table[][_no_of_types][_no_of_methods]) const
80+
template <std::uint8_t... VecSizes>
81+
void populate(funcPtrT table[][_no_of_types][_no_of_methods], std::integer_sequence<std::uint8_t, VecSizes...>) const
8182
{
82-
const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<8>>(),
83-
table_per_type_and_method<mkl_rng_dev::philox4x32x10<8>>(),
84-
table_per_type_and_method<mkl_rng_dev::mcg31m1<8>>(),
85-
table_per_type_and_method<mkl_rng_dev::mcg59<8>>()};
83+
const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>>()...,
84+
table_per_type_and_method<mkl_rng_dev::philox4x32x10<VecSizes>>()...,
85+
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>>()...,
86+
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>>()...};
8687
assert(map_by_engine.size() == _no_of_engines);
8788

8889
std::uint16_t engine_id = 0;

dpnp/backend/extensions/rng/device/engine/builder/base_builder.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class BaseBuilder {
111111

112112
// TODO: remove
113113
void print() {
114+
std::cout << "vector size = " << std::to_string(EngineT::vec_size) << std::endl;
114115
std::cout << "list_of_seeds: ";
115116
for (auto &val: seeds) {
116117
std::cout << std::to_string(val) << ", ";

dpnp/backend/extensions/rng/device/gaussian.cpp

+32-14
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <pybind11/pybind11.h>
2727

2828
// dpctl tensor headers
29+
#include "utils/output_validation.hpp"
2930
#include "utils/type_dispatch.hpp"
3031
#include "utils/type_utils.hpp"
3132
#include "kernels/alignment.hpp"
@@ -51,7 +52,22 @@ using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
5152
using dpctl_krn_ns::is_aligned;
5253
using dpctl_krn_ns::required_alignment;
5354

54-
constexpr int no_of_methods = 2; // number of methods of gaussian distribution
55+
constexpr auto no_of_methods = 2; // number of methods of gaussian distribution
56+
57+
constexpr auto seq_of_vec_sizes = std::integer_sequence<std::uint8_t, 2, 4, 8, 16>{};
58+
constexpr auto vec_sizes_len = seq_of_vec_sizes.size();
59+
constexpr auto no_of_engines = engine::no_of_engines * vec_sizes_len;
60+
61+
/**
 * @brief Locates @p vec_size inside the compile-time value pack @p Ints.
 *
 * Each value of @p Ints is paired with its position taken from @p Indices.
 * A matching value contributes its position, a non-matching value contributes
 * the sentinel `sizeof...(Indices)`, and the smallest contribution wins, so the
 * first match is reported when a value occurs more than once.
 *
 * @tparam VecSizeT integral type of the looked-up value and of the pack values
 * @tparam Ints     candidate values, in lookup order
 * @tparam Indices  positions `0 .. sizeof...(Ints) - 1`, deduced from the argument
 * @param vec_size  value to search for
 * @return zero-based position of the first match, or `sizeof...(Indices)`
 *         when @p vec_size is not part of @p Ints
 */
template <typename VecSizeT, VecSizeT ...Ints, auto ...Indices>
constexpr auto find_vec_size_impl(const VecSizeT vec_size, std::index_sequence<Indices...>) {
    // The leading sentinel keeps the initializer list non-empty, so an empty
    // sequence still compiles and reports "not found"; it never changes the
    // result for a non-empty sequence because every element is <= sentinel.
    return std::min({ sizeof...(Indices), ((Ints == vec_size) ? Indices : sizeof...(Indices))... });
}

/**
 * @brief Maps a vector size to its zero-based position in an integer sequence.
 *
 * @tparam VecSizeT integral type of the looked-up value and of the sequence
 * @tparam Ints     values held by the sequence
 * @param vec_size  value to search for
 * @return position of @p vec_size within the sequence, or -1 when the
 *         sequence does not contain it (including when the sequence is empty)
 */
template <typename VecSizeT, VecSizeT ...Ints>
constexpr int find_vec_size(const VecSizeT vec_size, std::integer_sequence<VecSizeT, Ints...>) {
    const auto res = find_vec_size_impl<VecSizeT, Ints...>(vec_size, std::make_index_sequence<sizeof...(Ints)>{});
    // Convert explicitly: mixing -1 with an unsigned index in one conditional
    // expression would first turn -1 into a huge unsigned value.
    return (res == sizeof...(Ints)) ? -1 : static_cast<int>(res);
}
5571

5672
template <typename DataT, typename Method>
5773
struct DistributorBuilder
@@ -83,7 +99,7 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(engine::EngineBase *engine,
8399
char *,
84100
const std::vector<sycl::event> &);
85101

86-
static gaussian_impl_fn_ptr_t gaussian_dispatch_table[engine::no_of_engines][dpctl_td_ns::num_types][no_of_methods];
102+
static gaussian_impl_fn_ptr_t gaussian_dispatch_table[no_of_engines][dpctl_td_ns::num_types][no_of_methods];
87103

88104
template <typename EngineT, typename DataT, typename Method, unsigned int items_per_wi>
89105
class gaussian_kernel;
@@ -117,7 +133,7 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
117133

118134
using EngineBuilderT = engine::builder::Builder<EngineT>;
119135
EngineBuilderT eng_builder(engine);
120-
eng_builder.print(); // TODO: remove
136+
// eng_builder.print(); // TODO: remove
121137

122138
using DistributorBuilderT = DistributorBuilder<DataT, Method>;
123139
DistributorBuilderT dist_builder(mean, stddev);
@@ -154,6 +170,7 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
154170

155171
std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
156172
const std::uint8_t method_id,
173+
const std::uint8_t vec_size,
157174
const double mean,
158175
const double stddev,
159176
const std::uint64_t n,
@@ -176,15 +193,10 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
176193
}
177194

178195
// ensure that output is ample enough to accommodate all elements
179-
auto res_offsets = res.get_minmax_offsets();
180-
// destination must be ample enough to accommodate all elements
181-
{
182-
size_t range =
183-
static_cast<size_t>(res_offsets.second - res_offsets.first);
184-
if (range + 1 < res_nelems) {
185-
throw py::value_error(
186-
"Destination array can not accommodate all the elements of source array.");
187-
}
196+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(res, res_nelems);
197+
198+
if (!dpctl::utils::queues_are_compatible(exec_q, {res})) {
199+
throw py::value_error("Execution queue is not compatible with the allocation queue");
188200
}
189201

190202
bool is_res_c_contig = res.is_c_contiguous();
@@ -201,6 +213,12 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
201213
throw std::runtime_error("Unknown method=" + std::to_string(method_id) + " for gaussian distribution.");
202214
}
203215

216+
int vec_size_id = find_vec_size(vec_size, seq_of_vec_sizes);
217+
if (vec_size_id < 0) {
218+
throw std::runtime_error("Vector size=" + std::to_string(vec_size) + " is out of supported range");
219+
}
220+
enginge_id = enginge_id * vec_sizes_len + vec_size_id;
221+
204222
auto array_types = dpctl_td_ns::usm_ndarray_types();
205223
int res_type_id = array_types.typenum_to_lookup_id(res.get_typenum());
206224

@@ -232,7 +250,7 @@ struct GaussianContigFactory
232250

233251
// Fills gaussian_dispatch_table through a Dispatch3DTableBuilder instantiated
// with GaussianContigFactory. In the post-commit ("+") lines the builder's
// engine dimension is no_of_engines and populate() additionally receives
// seq_of_vec_sizes; the "-" lines show the previous form without vector sizes.
void init_gaussian_dispatch_3d_table(void)
234252
{
235-
dispatch::Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t, GaussianContigFactory, engine::no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
236-
contig.populate(gaussian_dispatch_table);
253+
dispatch::Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t, GaussianContigFactory, no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
254+
contig.populate(gaussian_dispatch_table, seq_of_vec_sizes);
237255
}
238256
} // dpnp::backend::ext::rng::device

dpnp/backend/extensions/rng/device/gaussian.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ namespace dpnp::backend::ext::rng::device
3434
{
3535
extern std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
3636
const std::uint8_t method_id,
37+
const std::uint8_t vec_size,
3738
const double mean,
3839
const double stddev,
3940
const std::uint64_t n,

dpnp/backend/extensions/rng/device/rng_py.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ PYBIND11_MODULE(_rng_dev_impl, m)
100100
m.def("_gaussian", &rng_dev_ext::gaussian,
101101
"",
102102
py::arg("engine"),
103-
py::arg("method"), py::arg("mean"), py::arg("stddev"),
103+
py::arg("method_id"), py::arg("vec_size"),
104+
py::arg("mean"), py::arg("stddev"),
104105
py::arg("n"), py::arg("res"),
105106
py::arg("depends") = py::list());
106107
}

0 commit comments

Comments
 (0)