Skip to content

Commit 380badd

Browse files
committed
Applied pre-commit formatting rules
1 parent 14cc8a7 commit 380badd

22 files changed

+460
-616
lines changed

dpnp/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ build_dpnp_cython_ext_with_backend(dparray ${CMAKE_CURRENT_SOURCE_DIR}/dparray.p
5858
add_subdirectory(backend)
5959
add_subdirectory(backend/extensions/blas)
6060
add_subdirectory(backend/extensions/lapack)
61-
add_subdirectory(backend/extensions/rng)
6261
add_subdirectory(backend/extensions/rng/device)
6362
add_subdirectory(backend/extensions/vm)
6463
add_subdirectory(backend/extensions/sycl_ext)

dpnp/backend/extensions/rng/CMakeLists.txt

-74
This file was deleted.

dpnp/backend/extensions/rng/device/common_impl.hpp

+32-14
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30-
#include <sycl/sycl.hpp>
3130
#include <oneapi/mkl/rng/device.hpp>
31+
#include <sycl/sycl.hpp>
3232

3333
namespace dpnp
3434
{
@@ -57,11 +57,14 @@ struct RngContigFunctor
5757

5858
EngineBuilderT engine_;
5959
DistributorBuilderT distr_;
60-
DataT * const res_ = nullptr;
60+
DataT *const res_ = nullptr;
6161
const std::size_t nelems_;
6262

6363
public:
64-
RngContigFunctor(EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const std::size_t n_elems)
64+
RngContigFunctor(EngineBuilderT &engine,
65+
DistributorBuilderT &distr,
66+
DataT *res,
67+
const std::size_t n_elems)
6568
: engine_(engine), distr_(distr), res_(res), nelems_(n_elems)
6669
{
6770
}
@@ -82,31 +85,46 @@ struct RngContigFunctor
8285
DistrT distr = distr_();
8386

8487
if constexpr (enable_sg_load) {
85-
const std::size_t base = vi_per_wi * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
88+
const std::size_t base =
89+
vi_per_wi * (nd_it.get_group(0) * nd_it.get_local_range(0) +
90+
sg.get_group_id()[0] * max_sg_size);
8691

87-
if ((sg_size == max_sg_size) && (base + vi_per_wi * sg_size < nelems_)) {
92+
if ((sg_size == max_sg_size) &&
93+
(base + vi_per_wi * sg_size < nelems_)) {
8894
#pragma unroll
8995
for (std::uint16_t it = 0; it < vi_per_wi; it += vec_sz) {
90-
std::size_t offset = base + static_cast<std::size_t>(it) * static_cast<std::size_t>(sg_size);
91-
auto out_multi_ptr = sycl::address_space_cast<sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res_[offset]);
92-
93-
sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate<DistrT, EngineT>(distr, engine);
96+
std::size_t offset =
97+
base + static_cast<std::size_t>(it) *
98+
static_cast<std::size_t>(sg_size);
99+
auto out_multi_ptr = sycl::address_space_cast<
100+
sycl::access::address_space::global_space,
101+
sycl::access::decorated::yes>(&res_[offset]);
102+
103+
sycl::vec<DataT, vec_sz> rng_val_vec =
104+
mkl_rng_dev::generate<DistrT, EngineT>(distr, engine);
94105
sg.store<vec_sz>(out_multi_ptr, rng_val_vec);
95106
}
96107
}
97108
else {
98-
for (std::size_t offset = base + sg.get_local_id()[0]; offset < nelems_; offset += sg_size) {
99-
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
109+
for (std::size_t offset = base + sg.get_local_id()[0];
110+
offset < nelems_; offset += sg_size)
111+
{
112+
res_[offset] =
113+
mkl_rng_dev::generate_single<DistrT, EngineT>(distr,
114+
engine);
100115
}
101116
}
102117
}
103118
else {
104119
std::size_t base = nd_it.get_global_linear_id();
105120

106121
base = (base / sg_size) * sg_size * vi_per_wi + (base % sg_size);
107-
for (std::size_t offset = base; offset < std::min(nelems_, base + sg_size * vi_per_wi); offset += sg_size)
122+
for (std::size_t offset = base;
123+
offset < std::min(nelems_, base + sg_size * vi_per_wi);
124+
offset += sg_size)
108125
{
109-
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
126+
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(
127+
distr, engine);
110128
}
111129
}
112130
}
@@ -116,4 +134,4 @@ struct RngContigFunctor
116134
} // namespace rng
117135
} // namespace ext
118136
} // namespace backend
119-
} // namespace dpnp
137+
} // namespace dpnp

dpnp/backend/extensions/rng/device/dispatch/matrix.hpp

+20-8
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@
2929

3030
#include "utils/type_dispatch.hpp"
3131

32-
3332
namespace dpnp::backend::ext::rng::device::dispatch
3433
{
3534
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3635
namespace mkl_rng_dev = oneapi::mkl::rng::device;
3736

3837
template <typename Ty, typename ArgTy, typename Method, typename argMethod>
39-
struct TypePairDefinedEntry : std::bool_constant<std::is_same_v<Ty, ArgTy> &&
40-
std::is_same_v<Method, argMethod>>
38+
struct TypePairDefinedEntry
39+
: std::bool_constant<std::is_same_v<Ty, ArgTy> &&
40+
std::is_same_v<Method, argMethod>>
4141
{
4242
static constexpr bool is_defined = true;
4343
};
@@ -46,11 +46,23 @@ template <typename T, typename M>
4646
struct GaussianTypePairSupportFactory
4747
{
4848
static constexpr bool is_defined = std::disjunction<
49-
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::by_default>,
50-
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::box_muller2>,
51-
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::by_default>,
52-
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::box_muller2>,
49+
TypePairDefinedEntry<T,
50+
double,
51+
M,
52+
mkl_rng_dev::gaussian_method::by_default>,
53+
TypePairDefinedEntry<T,
54+
double,
55+
M,
56+
mkl_rng_dev::gaussian_method::box_muller2>,
57+
TypePairDefinedEntry<T,
58+
float,
59+
M,
60+
mkl_rng_dev::gaussian_method::by_default>,
61+
TypePairDefinedEntry<T,
62+
float,
63+
M,
64+
mkl_rng_dev::gaussian_method::box_muller2>,
5365
// fall-through
5466
dpctl_td_ns::NotDefinedEntry>::is_defined;
5567
};
56-
} // dpnp::backend::ext::rng::device::dispatch
68+
} // namespace dpnp::backend::ext::rng::device::dispatch

dpnp/backend/extensions/rng/device/dispatch/table_builder.hpp

+30-25
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727

2828
#include <oneapi/mkl/rng/device.hpp>
2929

30-
3130
namespace dpnp::backend::ext::rng::device::dispatch
3231
{
3332
namespace mkl_rng_dev = oneapi::mkl::rng::device;
3433

3534
template <typename funcPtrT,
36-
template <typename fnT, typename E, typename T, typename M> typename factory,
35+
template <typename fnT, typename E, typename T, typename M>
36+
typename factory,
3737
int _no_of_engines,
3838
int _no_of_types,
3939
int _no_of_methods>
@@ -44,8 +44,10 @@ class Dispatch3DTableBuilder
4444
const std::vector<funcPtrT> row_per_method() const
4545
{
4646
std::vector<funcPtrT> per_method = {
47-
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}.get(),
48-
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}.get(),
47+
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}
48+
.get(),
49+
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}
50+
.get(),
4951
};
5052
assert(per_method.size() == _no_of_methods);
5153
return per_method;
@@ -54,21 +56,21 @@ class Dispatch3DTableBuilder
5456
template <typename E>
5557
auto table_per_type_and_method() const
5658
{
57-
std::vector<std::vector<funcPtrT>>
58-
table_by_type = {row_per_method<E, bool>(),
59-
row_per_method<E, int8_t>(),
60-
row_per_method<E, uint8_t>(),
61-
row_per_method<E, int16_t>(),
62-
row_per_method<E, uint16_t>(),
63-
row_per_method<E, int32_t>(),
64-
row_per_method<E, uint32_t>(),
65-
row_per_method<E, int64_t>(),
66-
row_per_method<E, uint64_t>(),
67-
row_per_method<E, sycl::half>(),
68-
row_per_method<E, float>(),
69-
row_per_method<E, double>(),
70-
row_per_method<E, std::complex<float>>(),
71-
row_per_method<E, std::complex<double>>()};
59+
std::vector<std::vector<funcPtrT>> table_by_type = {
60+
row_per_method<E, bool>(),
61+
row_per_method<E, int8_t>(),
62+
row_per_method<E, uint8_t>(),
63+
row_per_method<E, int16_t>(),
64+
row_per_method<E, uint16_t>(),
65+
row_per_method<E, int32_t>(),
66+
row_per_method<E, uint32_t>(),
67+
row_per_method<E, int64_t>(),
68+
row_per_method<E, uint64_t>(),
69+
row_per_method<E, sycl::half>(),
70+
row_per_method<E, float>(),
71+
row_per_method<E, double>(),
72+
row_per_method<E, std::complex<float>>(),
73+
row_per_method<E, std::complex<double>>()};
7274
assert(table_by_type.size() == _no_of_types);
7375
return table_by_type;
7476
}
@@ -78,12 +80,15 @@ class Dispatch3DTableBuilder
7880
~Dispatch3DTableBuilder() = default;
7981

8082
template <std::uint8_t... VecSizes>
81-
void populate(funcPtrT table[][_no_of_types][_no_of_methods], std::integer_sequence<std::uint8_t, VecSizes...>) const
83+
void populate(funcPtrT table[][_no_of_types][_no_of_methods],
84+
std::integer_sequence<std::uint8_t, VecSizes...>) const
8285
{
83-
const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>>()...,
84-
table_per_type_and_method<mkl_rng_dev::philox4x32x10<VecSizes>>()...,
85-
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>>()...,
86-
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>>()...};
86+
const auto map_by_engine = {
87+
table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>>()...,
88+
table_per_type_and_method<
89+
mkl_rng_dev::philox4x32x10<VecSizes>>()...,
90+
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>>()...,
91+
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>>()...};
8792
assert(map_by_engine.size() == _no_of_engines);
8893

8994
std::uint16_t engine_id = 0;
@@ -101,4 +106,4 @@ class Dispatch3DTableBuilder
101106
}
102107
}
103108
};
104-
} // dpnp::backend::ext::rng::device::dispatch
109+
} // namespace dpnp::backend::ext::rng::device::dispatch

0 commit comments

Comments
 (0)