Skip to content

Commit 0d84775

Browse files
committed
Reworked engines classes whech bind with python
1 parent a8f20e3 commit 0d84775

File tree

9 files changed

+102
-143
lines changed

9 files changed

+102
-143
lines changed

dpnp/backend/extensions/rng/device/common_impl.hpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@
3030
#include <sycl/sycl.hpp>
3131
#include <oneapi/mkl/rng/device.hpp>
3232

33-
// dpctl tensor headers
34-
// #include "utils/offset_utils.hpp"
35-
3633
namespace dpnp
3734
{
3835
namespace backend
@@ -71,14 +68,15 @@ struct RngContigFunctor
7168

7269
void operator()(sycl::nd_item<1> nd_it) const
7370
{
74-
auto global_id = nd_it.get_global_id();
71+
// auto global_id = nd_it.get_global_id();
7572

7673
auto sg = nd_it.get_sub_group();
7774
const std::uint8_t sg_size = sg.get_local_range()[0];
7875
const std::uint8_t max_sg_size = sg.get_max_local_range()[0];
7976

8077
using EngineT = typename EngineBuilderT::EngineType;
81-
EngineT engine = engine_(nelems_ * global_id); // offset is questionable...
78+
// EngineT engine = engine_(nelems_ * global_id); // offset is questionable...
79+
EngineT engine = engine_();
8280

8381
using DistrT = typename DistributorBuilderT::distr_type;
8482
DistrT distr = distr_();

dpnp/backend/extensions/rng/device/engine/base_engine.hpp

+37-5
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,52 @@ class EngineType {
5959
constexpr int no_of_engines = EngineType::base_id();
6060

6161
class EngineBase {
62+
private:
63+
sycl::queue q_{};
64+
std::vector<std::uint64_t> seed_vec{};
65+
std::vector<std::uint64_t> offset_vec{};
66+
6267
public:
68+
EngineBase() {}
69+
70+
EngineBase(sycl::queue &q, std::uint64_t seed, std::uint64_t offset) :
71+
q_(q), seed_vec(1, seed), offset_vec(1, offset) {}
72+
73+
EngineBase(sycl::queue &q, std::vector<std::uint64_t> &seeds, std::uint64_t offset) :
74+
q_(q), seed_vec(seeds), offset_vec(1, offset) {}
75+
76+
EngineBase(sycl::queue &q, std::vector<std::uint32_t> &seeds, std::uint64_t offset) : q_(q), offset_vec(1, offset) {
77+
seed_vec.reserve(seeds.size());
78+
seed_vec.assign(seeds.begin(), seeds.end());
79+
}
80+
81+
EngineBase(sycl::queue &q, std::uint64_t seed, std::vector<std::uint64_t> &offsets) :
82+
q_(q), seed_vec(1, seed), offset_vec(offsets) {}
83+
84+
EngineBase(sycl::queue &q, std::vector<std::uint64_t> &seeds, std::vector<std::uint64_t> &offsets) :
85+
q_(q), seed_vec(seeds), offset_vec(offsets) {}
86+
87+
EngineBase(sycl::queue &q, std::vector<std::uint32_t> &seeds, std::vector<std::uint64_t> &offsets) : q_(q), offset_vec(offsets) {
88+
seed_vec.reserve(seeds.size());
89+
seed_vec.assign(seeds.begin(), seeds.end());
90+
}
91+
6392
virtual ~EngineBase() {}
64-
virtual sycl::queue &get_queue() = 0;
6593

6694
virtual EngineType get_type() const noexcept {
6795
return EngineType::Base;
6896
}
6997

70-
virtual std::vector<std::uint64_t> get_seeds() const noexcept {
71-
return std::vector<std::uint64_t>();
98+
sycl::queue &get_queue() noexcept {
99+
return q_;
100+
}
101+
102+
std::vector<std::uint64_t>& get_seeds() noexcept {
103+
return seed_vec;
72104
}
73105

74-
virtual std::vector<std::uint64_t> get_offsets() const noexcept {
75-
return std::vector<std::uint64_t>();
106+
std::vector<std::uint64_t>& get_offsets() noexcept {
107+
return offset_vec;
76108
}
77109
};
78110
} // dpnp::backend::ext::rng::device::engine

dpnp/backend/extensions/rng/device/engine/mcg31m1_engine.hpp

+4-20
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,15 @@
3131
namespace dpnp::backend::ext::rng::device::engine
3232
{
3333
class MCG31M1 : public EngineBase {
34-
private:
35-
sycl::queue q_;
36-
std::vector<std::uint64_t> seed_vec{};
37-
std::vector<std::uint64_t> offset_vec{};
38-
3934
public:
40-
MCG31M1(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) : q_(q) {
41-
seed_vec.push_back(seed);
42-
offset_vec.push_back(offset);
43-
}
35+
MCG31M1(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) :
36+
EngineBase(q, seed, offset) {}
4437

45-
sycl::queue &get_queue() override {
46-
return q_;
47-
}
38+
MCG31M1(sycl::queue &q, std::vector<std::uint32_t> &seeds, std::uint64_t offset = 0) :
39+
EngineBase(q, seeds, offset) {}
4840

4941
virtual EngineType get_type() const noexcept override {
5042
return EngineType::MCG31M1;
5143
}
52-
53-
virtual std::vector<std::uint64_t> get_seeds() const noexcept override {
54-
return seed_vec;
55-
}
56-
57-
virtual std::vector<std::uint64_t> get_offsets() const noexcept override {
58-
return offset_vec;
59-
}
6044
};
6145
} // dpnp::backend::ext::rng::device::engine

dpnp/backend/extensions/rng/device/engine/mcg59_engine.hpp

+4-20
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,15 @@
3131
namespace dpnp::backend::ext::rng::device::engine
3232
{
3333
class MCG59 : public EngineBase {
34-
private:
35-
sycl::queue q_;
36-
std::vector<std::uint64_t> seed_vec{};
37-
std::vector<std::uint64_t> offset_vec{};
38-
3934
public:
40-
MCG59(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) : q_(q) {
41-
seed_vec.push_back(seed);
42-
offset_vec.push_back(offset);
43-
}
35+
MCG59(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) :
36+
EngineBase(q, seed, offset) {}
4437

45-
sycl::queue &get_queue() override {
46-
return q_;
47-
}
38+
MCG59(sycl::queue &q, std::vector<std::uint32_t> &seeds, std::uint64_t offset = 0) :
39+
EngineBase(q, seeds, offset) {}
4840

4941
virtual EngineType get_type() const noexcept override {
5042
return EngineType::MCG59;
5143
}
52-
53-
virtual std::vector<std::uint64_t> get_seeds() const noexcept override {
54-
return seed_vec;
55-
}
56-
57-
virtual std::vector<std::uint64_t> get_offsets() const noexcept override {
58-
return offset_vec;
59-
}
6044
};
6145
} // dpnp::backend::ext::rng::device::engine

dpnp/backend/extensions/rng/device/engine/mrg32k3a_engine.hpp

+10-20
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,21 @@
3131
namespace dpnp::backend::ext::rng::device::engine
3232
{
3333
class MRG32k3a : public EngineBase {
34-
private:
35-
sycl::queue q_;
36-
std::vector<std::uint64_t> seed_vec{};
37-
std::vector<std::uint64_t> offset_vec{};
38-
3934
public:
40-
MRG32k3a(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) : q_(q) {
41-
seed_vec.push_back(seed);
42-
offset_vec.push_back(offset);
43-
}
35+
MRG32k3a(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) :
36+
EngineBase(q, seed, offset) {}
4437

45-
sycl::queue &get_queue() override {
46-
return q_;
47-
}
38+
MRG32k3a(sycl::queue &q, std::vector<std::uint32_t> &seeds, std::uint64_t offset = 0) :
39+
EngineBase(q, seeds, offset) {}
4840

49-
virtual EngineType get_type() const noexcept override {
50-
return EngineType::MRG32k3a;
51-
}
41+
MRG32k3a(sycl::queue &q, std::uint32_t seed, std::vector<std::uint64_t> &offsets) :
42+
EngineBase(q, seed, offsets) {}
5243

53-
virtual std::vector<std::uint64_t> get_seeds() const noexcept override {
54-
return seed_vec;
55-
}
44+
MRG32k3a(sycl::queue &q, std::vector<std::uint32_t> &seeds, std::vector<std::uint64_t> &offsets) :
45+
EngineBase(q, seeds, offsets) {}
5646

57-
virtual std::vector<std::uint64_t> get_offsets() const noexcept override {
58-
return offset_vec;
47+
virtual EngineType get_type() const noexcept override {
48+
return EngineType::MRG32k3a;
5949
}
6050
};
6151
} // dpnp::backend::ext::rng::device::engine

dpnp/backend/extensions/rng/device/engine/philox4x32x10_engine.hpp

+10-20
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,21 @@
3131
namespace dpnp::backend::ext::rng::device::engine
3232
{
3333
class PHILOX4x32x10 : public EngineBase {
34-
private:
35-
sycl::queue q_;
36-
std::vector<std::uint64_t> seed_vec{};
37-
std::vector<std::uint64_t> offset_vec{};
38-
3934
public:
40-
PHILOX4x32x10(sycl::queue &q, std::uint64_t seed, std::uint64_t offset = 0) : q_(q) {
41-
seed_vec.push_back(seed);
42-
offset_vec.push_back(offset);
43-
}
35+
PHILOX4x32x10(sycl::queue &q, std::uint64_t seed, std::uint64_t offset = 0) :
36+
EngineBase(q, seed, offset) {}
4437

45-
sycl::queue &get_queue() override {
46-
return q_;
47-
}
38+
PHILOX4x32x10(sycl::queue &q, std::vector<std::uint64_t> &seeds, std::uint64_t offset = 0) :
39+
EngineBase(q, seeds, offset) {}
4840

49-
virtual EngineType get_type() const noexcept override {
50-
return EngineType::PHILOX4x32x10;
51-
}
41+
PHILOX4x32x10(sycl::queue &q, std::uint64_t seed, std::vector<std::uint64_t> &offsets) :
42+
EngineBase(q, seed, offsets) {}
5243

53-
virtual std::vector<std::uint64_t> get_seeds() const noexcept override {
54-
return seed_vec;
55-
}
44+
PHILOX4x32x10(sycl::queue &q, std::vector<std::uint64_t> &seeds, std::vector<std::uint64_t> &offsets) :
45+
EngineBase(q, seeds, offsets) {}
5646

57-
virtual std::vector<std::uint64_t> get_offsets() const noexcept override {
58-
return offset_vec;
47+
virtual EngineType get_type() const noexcept override {
48+
return EngineType::PHILOX4x32x10;
5949
}
6050
};
6151
} // dpnp::backend::ext::rng::device::engine

dpnp/backend/extensions/rng/device/gaussian.cpp

+6-22
Original file line numberDiff line numberDiff line change
@@ -26,38 +26,26 @@
2626
#include <pybind11/pybind11.h>
2727

2828
// dpctl tensor headers
29-
// #include "utils/memory_overlap.hpp"
3029
#include "utils/type_dispatch.hpp"
3130
#include "utils/type_utils.hpp"
32-
33-
// dpctl tensor headers
3431
#include "kernels/alignment.hpp"
3532

36-
#include "common_impl.hpp"
3733
#include "gaussian.hpp"
34+
#include "common_impl.hpp"
3835

39-
#include "engine/base_engine.hpp"
4036
#include "engine/builder/builder.hpp"
4137

4238
#include "dispatch/matrix.hpp"
4339
#include "dispatch/table_builder.hpp"
4440

4541

46-
namespace dpnp
47-
{
48-
namespace backend
49-
{
50-
namespace ext
51-
{
52-
namespace rng
53-
{
54-
namespace device
42+
namespace dpnp::backend::ext::rng::device
5543
{
5644
namespace dpctl_krn_ns = dpctl::tensor::kernels::alignment_utils;
5745
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
46+
namespace dpctl_tu_ns = dpctl::tensor::type_utils;
5847
namespace mkl_rng_dev = oneapi::mkl::rng::device;
5948
namespace py = pybind11;
60-
namespace type_utils = dpctl::tensor::type_utils;
6149

6250
using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
6351
using dpctl_krn_ns::is_aligned;
@@ -109,7 +97,7 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
10997
const std::vector<sycl::event> &depends)
11098
{
11199
auto &exec_q = engine->get_queue();
112-
type_utils::validate_type_for_device<DataT>(exec_q);
100+
dpctl_tu_ns::validate_type_for_device<DataT>(exec_q);
113101

114102
DataT *out = reinterpret_cast<DataT *>(out_ptr);
115103
DataT mean = static_cast<DataT>(mean_val);
@@ -242,13 +230,9 @@ struct GaussianContigFactory
242230
}
243231
};
244232

245-
void init_gaussian_dispatch_table(void)
233+
void init_gaussian_dispatch_3d_table(void)
246234
{
247235
dispatch::Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t, GaussianContigFactory, engine::no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
248236
contig.populate(gaussian_dispatch_table);
249237
}
250-
} // namespace device
251-
} // namespace rng
252-
} // namespace ext
253-
} // namespace backend
254-
} // namespace dpnp
238+
} // dpnp::backend::ext::rng::device

dpnp/backend/extensions/rng/device/gaussian.hpp

+1-5
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@
2525

2626
#pragma once
2727

28-
#include <sycl/sycl.hpp>
29-
#include <oneapi/mkl.hpp>
30-
#include <oneapi/mkl/rng/device.hpp>
31-
3228
#include <dpctl4pybind11.hpp>
3329

3430
#include "engine/base_engine.hpp"
@@ -44,5 +40,5 @@ extern std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
4440
dpctl::tensor::usm_ndarray res,
4541
const std::vector<sycl::event> &depends = {});
4642

47-
extern void init_gaussian_dispatch_table(void);
43+
extern void init_gaussian_dispatch_3d_table(void);
4844
} // namespace dpnp::backend::ext::rng::device

0 commit comments

Comments
 (0)