Skip to content

Commit f16f5cc

Browse files
committed
Corrected offset usage
1 parent 0d84775 commit f16f5cc

File tree

3 files changed

+45
-26
lines changed

3 files changed

+45
-26
lines changed

dpnp/backend/extensions/rng/device/common_impl.hpp

+14-16
Original file line numberDiff line numberDiff line change
@@ -58,55 +58,53 @@ struct RngContigFunctor
5858
EngineBuilderT engine_;
5959
DistributorBuilderT distr_;
6060
DataT * const res_ = nullptr;
61-
const size_t nelems_;
61+
const std::size_t nelems_;
6262

6363
public:
64-
// Constructor: copies the engine builder and the distribution builder into the
// functor and stores the output pointer `res` together with the element count
// `n_elems`; the body is intentionally empty.
// Diff note: the `size_t` signature below is the removed line, the
// `std::size_t` signature after the `64+` marker is its replacement.
RngContigFunctor(EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const size_t n_elems)
64+
RngContigFunctor(EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const std::size_t n_elems)
6565
: engine_(engine), distr_(distr), res_(res), nelems_(n_elems)
6666
{
6767
}
6868

6969
void operator()(sycl::nd_item<1> nd_it) const
7070
{
71-
// auto global_id = nd_it.get_global_id();
72-
7371
auto sg = nd_it.get_sub_group();
7472
const std::uint8_t sg_size = sg.get_local_range()[0];
7573
const std::uint8_t max_sg_size = sg.get_max_local_range()[0];
7674

7775
using EngineT = typename EngineBuilderT::EngineType;
78-
// EngineT engine = engine_(nelems_ * global_id); // offset is questionable...
79-
EngineT engine = engine_();
80-
8176
using DistrT = typename DistributorBuilderT::distr_type;
82-
DistrT distr = distr_();
8377

8478
constexpr std::size_t vec_sz = EngineT::vec_size;
79+
constexpr std::size_t vi_per_wi = vec_sz * items_per_wi;
80+
81+
EngineT engine = engine_(nd_it.get_global_id() * vi_per_wi);
82+
DistrT distr = distr_();
8583

8684
if constexpr (enable_sg_load) {
87-
const size_t base = items_per_wi * vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
85+
const std::size_t base = vi_per_wi * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
8886

89-
if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
87+
if ((sg_size == max_sg_size) && (base + vi_per_wi * sg_size < nelems_)) {
9088
#pragma unroll
91-
for (std::uint16_t it = 0; it < items_per_wi * vec_sz; it += vec_sz) {
92-
size_t offset = base + static_cast<size_t>(it) * static_cast<size_t>(sg_size);
89+
for (std::uint16_t it = 0; it < vi_per_wi; it += vec_sz) {
90+
std::size_t offset = base + static_cast<std::size_t>(it) * static_cast<std::size_t>(sg_size);
9391
auto out_multi_ptr = sycl::address_space_cast<sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res_[offset]);
9492

9593
sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate<DistrT, EngineT>(distr, engine);
9694
sg.store<vec_sz>(out_multi_ptr, rng_val_vec);
9795
}
9896
}
9997
else {
100-
for (size_t offset = base + sg.get_local_id()[0]; offset < nelems_; offset += sg_size) {
98+
for (std::size_t offset = base + sg.get_local_id()[0]; offset < nelems_; offset += sg_size) {
10199
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
102100
}
103101
}
104102
}
105103
else {
106-
size_t base = nd_it.get_global_linear_id();
104+
std::size_t base = nd_it.get_global_linear_id();
107105

108-
base = (base / sg_size) * sg_size * items_per_wi * vec_sz + (base % sg_size);
109-
for (size_t offset = base; offset < std::min(nelems_, base + sg_size * (items_per_wi * vec_sz)); offset += sg_size)
106+
base = (base / sg_size) * sg_size * vi_per_wi + (base % sg_size);
107+
for (std::size_t offset = base; offset < std::min(nelems_, base + sg_size * vi_per_wi); offset += sg_size)
110108
{
111109
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
112110
}

dpnp/backend/extensions/rng/device/engine/base_engine.hpp

+25-4
Original file line numberDiff line numberDiff line change
@@ -56,35 +56,53 @@ class EngineType {
5656
};
5757

5858
// A total number of supported engines == EngineType::Base
59-
constexpr int no_of_engines = EngineType::base_id();
59+
constexpr std::uint8_t no_of_engines = EngineType::base_id();
6060

6161
class EngineBase {
6262
private:
6363
sycl::queue q_{};
6464
std::vector<std::uint64_t> seed_vec{};
6565
std::vector<std::uint64_t> offset_vec{};
6666

67+
// Checks that a seed or offset vector of `size` elements is supported.
// @param size  number of elements in the vector being validated
// @throws std::runtime_error when `size` is greater than max_vec_n
void validate_vec_size(const std::size_t size) {
68+
if (size > max_vec_n) {
69+
// Descriptive message instead of the former "TODO" placeholder, so the
// caller can tell which precondition was violated.
throw std::runtime_error("EngineBase: number of seeds or offsets exceeds the maximum supported vector size");
70+
}
71+
}
72+
6773
public:
6874
// Default constructor: the queue and the seed/offset vectors keep their
// default member initializers (default queue, empty vectors).
EngineBase() {}
6975

7076
// Builds from one seed and one offset; each scalar is stored as a
// one-element vector. No size validation is needed for this overload.
EngineBase(sycl::queue &q, std::uint64_t seed, std::uint64_t offset) :
7177
q_(q), seed_vec(1, seed), offset_vec(1, offset) {}
7278

7379
// Builds from a vector of 64-bit seeds and a single offset (stored as a
// one-element vector).
// Diff note: the initializer list after the `74-` marker is the removed line;
// the lines after the `+` markers replace it and add the seed-count check.
EngineBase(sycl::queue &q, std::vector<std::uint64_t> &seeds, std::uint64_t offset) :
74-
q_(q), seed_vec(seeds), offset_vec(1, offset) {}
80+
q_(q), seed_vec(seeds), offset_vec(1, offset) {
81+
// Runs after `seeds` has already been copied into seed_vec; throws
// std::runtime_error when the seed count exceeds max_vec_n.
validate_vec_size(seeds.size());
82+
}
7583

7684
// Builds from a vector of 32-bit seeds and a single offset (stored as a
// one-element vector).
EngineBase(sycl::queue &q, std::vector<std::uint32_t> &seeds, std::uint64_t offset) : q_(q), offset_vec(1, offset) {
85+
// Reject an unsupported seed count before any seed is stored.
validate_vec_size(seeds.size());
86+
7787
seed_vec.reserve(seeds.size());
7888
// assign() converts every 32-bit seed to the 64-bit element type of seed_vec.
seed_vec.assign(seeds.begin(), seeds.end());
7989
}
8090

8191
// Builds from a single seed (stored as a one-element vector) and a vector of
// offsets.
// Diff note: the initializer list after the `82-` marker is the removed line;
// the lines after the `+` markers replace it and add the offset-count check.
EngineBase(sycl::queue &q, std::uint64_t seed, std::vector<std::uint64_t> &offsets) :
82-
q_(q), seed_vec(1, seed), offset_vec(offsets) {}
92+
q_(q), seed_vec(1, seed), offset_vec(offsets) {
93+
// Runs after `offsets` has already been copied into offset_vec; throws
// std::runtime_error when the offset count exceeds max_vec_n.
validate_vec_size(offsets.size());
94+
}
8395

8496
// Builds from a vector of 64-bit seeds and a vector of offsets.
// Diff note: the initializer list after the `85-` marker is the removed line;
// the lines after the `+` markers replace it and add both size checks.
EngineBase(sycl::queue &q, std::vector<std::uint64_t> &seeds, std::vector<std::uint64_t> &offsets) :
85-
q_(q), seed_vec(seeds), offset_vec(offsets) {}
97+
q_(q), seed_vec(seeds), offset_vec(offsets) {
98+
// Both checks run after the vectors were copied by the initializer list;
// either one throws std::runtime_error when its size exceeds max_vec_n.
validate_vec_size(seeds.size());
99+
validate_vec_size(offsets.size());
100+
}
86101

87102
// Builds from a vector of 32-bit seeds and a vector of offsets.
EngineBase(sycl::queue &q, std::vector<std::uint32_t> &seeds, std::vector<std::uint64_t> &offsets) : q_(q), offset_vec(offsets) {
103+
// The seed count is checked before any seed is stored; offsets were already
// copied by the initializer list when their count is checked.
validate_vec_size(seeds.size());
104+
validate_vec_size(offsets.size());
105+
88106
seed_vec.reserve(seeds.size());
89107
// assign() converts every 32-bit seed to the 64-bit element type of seed_vec.
seed_vec.assign(seeds.begin(), seeds.end());
90108
}
@@ -106,5 +124,8 @@ class EngineBase {
106124
// Returns a mutable reference to the stored offset vector; never throws.
std::vector<std::uint64_t>& get_offsets() noexcept {
107125
return offset_vec;
108126
}
127+
128+
// Maximum number of elements accepted in a seed or offset vector (enforced by validate_vec_size)
129+
static constexpr std::uint8_t max_vec_n = 1;
109130
};
110131
} // dpnp::backend::ext::rng::device::engine

dpnp/backend/extensions/rng/device/engine/builder/base_builder.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace dpnp::backend::ext::rng::device::engine::builder
3737
template <typename EngineT, typename SeedT, typename OffsetT>
3838
class BaseBuilder {
3939
private:
40-
static constexpr std::uint8_t max_n = 10;
40+
static constexpr std::uint8_t max_n = EngineBase::max_vec_n;
4141

4242
std::uint8_t no_of_seeds;
4343
std::uint8_t no_of_offsets;
@@ -75,7 +75,7 @@ class BaseBuilder {
7575
{
7676
switch (no_of_seeds) {
7777
case 1: {
78-
if constexpr (std::is_same_v<EngineT, mkl_rng_dev::mcg59<8>>) {
78+
if constexpr (std::is_same_v<EngineT, mkl_rng_dev::mcg59<EngineT::vec_size>>) {
7979
// issue with mcg59<>() constructor which breaks compilation
8080
return EngineT(seeds[0], offsets[0]);
8181
}
@@ -90,16 +90,16 @@ class BaseBuilder {
9090
return EngineT();
9191
}
9292

93-
inline auto operator()(OffsetT offset) const
93+
inline auto operator()(const OffsetT offset) const
9494
{
9595
switch (no_of_seeds) {
9696
case 1: {
97-
if constexpr (std::is_same_v<EngineT, mkl_rng_dev::mcg59<8>>) {
97+
if constexpr (std::is_same_v<EngineT, mkl_rng_dev::mcg59<EngineT::vec_size>>) {
9898
// issue with mcg59<>() constructor which breaks compilation
99-
return EngineT(seeds[0], offsets[0]);
99+
return EngineT(seeds[0], offsets[0] + offset);
100100
}
101101
else {
102-
return EngineT({seeds[0]}, {offset});
102+
return EngineT({seeds[0]}, {offsets[0] + offset});
103103
}
104104
}
105105
// TODO: implement full switch

0 commit comments

Comments
 (0)