Skip to content

Commit 199f13b

Browse files
committed
Added uniform distribution
1 parent 9d9540e commit 199f13b

File tree

7 files changed

+387
-49
lines changed

7 files changed

+387
-49
lines changed

dpnp/backend/extensions/rng/device/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(python_module_name _rng_dev_impl)
2828
pybind11_add_module(${python_module_name} MODULE
2929
rng_py.cpp
3030
gaussian.cpp
31+
uniform.cpp
3132
)
3233

3334
if (WIN32)

dpnp/backend/extensions/rng/device/common_impl.hpp

+2-17
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,7 @@
3030
#include <oneapi/mkl/rng/device.hpp>
3131
#include <sycl/sycl.hpp>
3232

33-
namespace dpnp
34-
{
35-
namespace backend
36-
{
37-
namespace ext
38-
{
39-
namespace rng
40-
{
41-
namespace device
42-
{
43-
namespace details
33+
namespace dpnp::backend::ext::rng::device::details
4434
{
4535
namespace py = pybind11;
4636

@@ -129,9 +119,4 @@ struct RngContigFunctor
129119
}
130120
}
131121
};
132-
} // namespace details
133-
} // namespace device
134-
} // namespace rng
135-
} // namespace ext
136-
} // namespace backend
137-
} // namespace dpnp
122+
} // namespace dpnp::backend::ext::rng::device::details

dpnp/backend/extensions/rng/device/dispatch/matrix.hpp

+20-4
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,35 @@ struct GaussianTypePairSupportFactory
4949
TypePairDefinedEntry<T,
5050
double,
5151
M,
52-
mkl_rng_dev::gaussian_method::by_default>,
52+
mkl_rng_dev::gaussian_method::box_muller2>,
5353
TypePairDefinedEntry<T,
54-
double,
54+
float,
5555
M,
5656
mkl_rng_dev::gaussian_method::box_muller2>,
57+
// fall-through
58+
dpctl_td_ns::NotDefinedEntry>::is_defined;
59+
};
60+
61+
template <typename T, typename M>
62+
struct UniformTypePairSupportFactory
63+
{
64+
static constexpr bool is_defined = std::disjunction<
65+
TypePairDefinedEntry<T,
66+
double,
67+
M,
68+
mkl_rng_dev::uniform_method::standard>,
69+
TypePairDefinedEntry<T,
70+
double,
71+
M,
72+
mkl_rng_dev::uniform_method::accurate>,
5773
TypePairDefinedEntry<T,
5874
float,
5975
M,
60-
mkl_rng_dev::gaussian_method::by_default>,
76+
mkl_rng_dev::uniform_method::standard>,
6177
TypePairDefinedEntry<T,
6278
float,
6379
M,
64-
mkl_rng_dev::gaussian_method::box_muller2>,
80+
mkl_rng_dev::uniform_method::accurate>,
6581
// fall-through
6682
dpctl_td_ns::NotDefinedEntry>::is_defined;
6783
};

dpnp/backend/extensions/rng/device/dispatch/table_builder.hpp

+22-26
Original file line numberDiff line numberDiff line change
@@ -40,37 +40,34 @@ template <typename funcPtrT,
4040
class Dispatch3DTableBuilder
4141
{
4242
private:
43-
template <typename E, typename T>
43+
template <typename E, typename T, typename... Methods>
4444
const std::vector<funcPtrT> row_per_method() const
4545
{
4646
std::vector<funcPtrT> per_method = {
47-
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}
48-
.get(),
49-
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}
50-
.get(),
47+
factory<funcPtrT, E, T, Methods>{}.get()...,
5148
};
5249
assert(per_method.size() == _no_of_methods);
5350
return per_method;
5451
}
5552

56-
template <typename E>
53+
template <typename E, typename... Methods>
5754
auto table_per_type_and_method() const
5855
{
5956
std::vector<std::vector<funcPtrT>> table_by_type = {
60-
row_per_method<E, bool>(),
61-
row_per_method<E, int8_t>(),
62-
row_per_method<E, uint8_t>(),
63-
row_per_method<E, int16_t>(),
64-
row_per_method<E, uint16_t>(),
65-
row_per_method<E, int32_t>(),
66-
row_per_method<E, uint32_t>(),
67-
row_per_method<E, int64_t>(),
68-
row_per_method<E, uint64_t>(),
69-
row_per_method<E, sycl::half>(),
70-
row_per_method<E, float>(),
71-
row_per_method<E, double>(),
72-
row_per_method<E, std::complex<float>>(),
73-
row_per_method<E, std::complex<double>>()};
57+
row_per_method<E, bool, Methods...>(),
58+
row_per_method<E, int8_t, Methods...>(),
59+
row_per_method<E, uint8_t, Methods...>(),
60+
row_per_method<E, int16_t, Methods...>(),
61+
row_per_method<E, uint16_t, Methods...>(),
62+
row_per_method<E, int32_t, Methods...>(),
63+
row_per_method<E, uint32_t, Methods...>(),
64+
row_per_method<E, int64_t, Methods...>(),
65+
row_per_method<E, uint64_t, Methods...>(),
66+
row_per_method<E, sycl::half, Methods...>(),
67+
row_per_method<E, float, Methods...>(),
68+
row_per_method<E, double, Methods...>(),
69+
row_per_method<E, std::complex<float>, Methods...>(),
70+
row_per_method<E, std::complex<double>, Methods...>()};
7471
assert(table_by_type.size() == _no_of_types);
7572
return table_by_type;
7673
}
@@ -79,16 +76,15 @@ class Dispatch3DTableBuilder
7976
Dispatch3DTableBuilder() = default;
8077
~Dispatch3DTableBuilder() = default;
8178

82-
template <std::uint8_t... VecSizes>
79+
template <typename... Methods, std::uint8_t... VecSizes>
8380
void populate(funcPtrT table[][_no_of_types][_no_of_methods],
8481
std::integer_sequence<std::uint8_t, VecSizes...>) const
8582
{
8683
const auto map_by_engine = {
87-
table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>>()...,
88-
table_per_type_and_method<
89-
mkl_rng_dev::philox4x32x10<VecSizes>>()...,
90-
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>>()...,
91-
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>>()...};
84+
table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>, Methods...>()...,
85+
table_per_type_and_method<mkl_rng_dev::philox4x32x10<VecSizes>, Methods...>()...,
86+
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>, Methods...>()...,
87+
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>, Methods...>()...};
9288
assert(map_by_engine.size() == _no_of_engines);
9389

9490
std::uint16_t engine_id = 0;

dpnp/backend/extensions/rng/device/gaussian.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
5151
using dpctl_krn_ns::is_aligned;
5252
using dpctl_krn_ns::required_alignment;
5353

54-
constexpr auto no_of_methods = 2; // number of methods of gaussian distribution
54+
constexpr auto no_of_methods = 1; // number of methods of gaussian distribution
5555

5656
constexpr auto seq_of_vec_sizes =
5757
std::integer_sequence<std::uint8_t, 2, 4, 8, 16>{};
@@ -291,6 +291,6 @@ void init_gaussian_dispatch_3d_table(void)
291291
GaussianContigFactory, no_of_engines,
292292
dpctl_td_ns::num_types, no_of_methods>
293293
contig;
294-
contig.populate(gaussian_dispatch_table, seq_of_vec_sizes);
294+
contig.populate<mkl_rng_dev::gaussian_method::box_muller2>(gaussian_dispatch_table, seq_of_vec_sizes);
295295
}
296296
} // namespace dpnp::backend::ext::rng::device

0 commit comments

Comments
 (0)