@@ -58,55 +58,53 @@ struct RngContigFunctor
58
58
EngineBuilderT engine_;
59
59
DistributorBuilderT distr_;
60
60
DataT * const res_ = nullptr ;
61
- const size_t nelems_;
61
+ const std:: size_t nelems_;
62
62
63
63
public:
64
- RngContigFunctor (EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const size_t n_elems)
64
+ RngContigFunctor (EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const std:: size_t n_elems)
65
65
: engine_(engine), distr_(distr), res_(res), nelems_(n_elems)
66
66
{
67
67
}
68
68
69
69
void operator ()(sycl::nd_item<1 > nd_it) const
70
70
{
71
- // auto global_id = nd_it.get_global_id();
72
-
73
71
auto sg = nd_it.get_sub_group ();
74
72
const std::uint8_t sg_size = sg.get_local_range ()[0 ];
75
73
const std::uint8_t max_sg_size = sg.get_max_local_range ()[0 ];
76
74
77
75
using EngineT = typename EngineBuilderT::EngineType;
78
- // EngineT engine = engine_(nelems_ * global_id); // offset is questionable...
79
- EngineT engine = engine_ ();
80
-
81
76
using DistrT = typename DistributorBuilderT::distr_type;
82
- DistrT distr = distr_ ();
83
77
84
78
constexpr std::size_t vec_sz = EngineT::vec_size;
79
+ constexpr std::size_t vi_per_wi = vec_sz * items_per_wi;
80
+
81
+ EngineT engine = engine_ (nd_it.get_global_id () * vi_per_wi);
82
+ DistrT distr = distr_ ();
85
83
86
84
if constexpr (enable_sg_load) {
87
- const size_t base = items_per_wi * vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
85
+ const std:: size_t base = vi_per_wi * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
88
86
89
- if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
87
+ if ((sg_size == max_sg_size) && (base + vi_per_wi * sg_size < nelems_)) {
90
88
#pragma unroll
91
- for (std::uint16_t it = 0 ; it < items_per_wi * vec_sz ; it += vec_sz) {
92
- size_t offset = base + static_cast <size_t >(it) * static_cast <size_t >(sg_size);
89
+ for (std::uint16_t it = 0 ; it < vi_per_wi ; it += vec_sz) {
90
+ std:: size_t offset = base + static_cast <std:: size_t >(it) * static_cast <std:: size_t >(sg_size);
93
91
auto out_multi_ptr = sycl::address_space_cast<sycl::access ::address_space::global_space, sycl::access ::decorated::yes>(&res_[offset]);
94
92
95
93
sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate<DistrT, EngineT>(distr, engine);
96
94
sg.store <vec_sz>(out_multi_ptr, rng_val_vec);
97
95
}
98
96
}
99
97
else {
100
- for (size_t offset = base + sg.get_local_id ()[0 ]; offset < nelems_; offset += sg_size) {
98
+ for (std:: size_t offset = base + sg.get_local_id ()[0 ]; offset < nelems_; offset += sg_size) {
101
99
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
102
100
}
103
101
}
104
102
}
105
103
else {
106
- size_t base = nd_it.get_global_linear_id ();
104
+ std:: size_t base = nd_it.get_global_linear_id ();
107
105
108
- base = (base / sg_size) * sg_size * items_per_wi * vec_sz + (base % sg_size);
109
- for (size_t offset = base; offset < std::min (nelems_, base + sg_size * (items_per_wi * vec_sz) ); offset += sg_size)
106
+ base = (base / sg_size) * sg_size * vi_per_wi + (base % sg_size);
107
+ for (std:: size_t offset = base; offset < std::min (nelems_, base + sg_size * vi_per_wi ); offset += sg_size)
110
108
{
111
109
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
112
110
}
0 commit comments