Skip to content

Commit 3b6a339

Browse files
ENH: kernels for random.vonmises; part 2 (#779)
1 parent 2ae0153 commit 3b6a339

File tree

1 file changed

+32
-28
lines changed

1 file changed

+32
-28
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

+32-28
Original file line numberDiff line numberDiff line change
@@ -1244,7 +1244,7 @@ void dpnp_rng_uniform_c(void* result, const long low, const long high, const siz
12441244
template <typename _DataType>
12451245
void dpnp_rng_vonmises_large_kappa_c(void* result, const _DataType mu, const _DataType kappa, const size_t size)
12461246
{
1247-
if (!size)
1247+
if (!size || !result)
12481248
{
12491249
return;
12501250
}
@@ -1314,20 +1314,23 @@ void dpnp_rng_vonmises_large_kappa_c(void* result, const _DataType mu, const _Da
13141314
dpnp_memory_free_c(Uvec);
13151315

13161316
mkl_rng::uniform<_DataType> uniform_distribution(d_zero, d_one);
1317-
auto event_out = mkl_rng::generate(uniform_distribution, DPNP_RNG_ENGINE, size, Vvec);
1318-
event_out.wait();
1317+
auto uniform_distr_event = mkl_rng::generate(uniform_distribution, DPNP_RNG_ENGINE, size, Vvec);
13191318

1320-
// TODO
1321-
// kernel
1322-
for (size_t i = 0; i < size; i++)
1323-
{
1324-
_DataType mod, resi;
1319+
cl::sycl::range<1> gws(size);
13251320

1326-
resi = (Vvec[i] < 0.5) ? mu - result1[i] : mu + result1[i];
1327-
mod = fabs(resi);
1328-
mod = (fmod(mod + M_PI, 2 * M_PI) - M_PI);
1329-
result1[i] = (resi < 0) ? -mod : mod;
1330-
}
1321+
auto paral_kernel_acceptance = [&](cl::sycl::handler& cgh) {
1322+
cgh.depends_on({uniform_distr_event});
1323+
cgh.parallel_for(gws, [=](cl::sycl::id<1> global_id) {
1324+
size_t i = global_id[0];
1325+
double mod, resi;
1326+
resi = (Vvec[i] < 0.5) ? mu - result1[i] : mu + result1[i];
1327+
mod = cl::sycl::fabs(resi);
1328+
mod = (cl::sycl::fmod(mod + M_PI, 2 * M_PI) - M_PI);
1329+
result1[i] = (resi < 0) ? -mod : mod;
1330+
});
1331+
};
1332+
auto acceptance_event = DPNP_QUEUE.submit(paral_kernel_acceptance);
1333+
acceptance_event.wait();
13311334

13321335
dpnp_memory_free_c(Vvec);
13331336
return;
@@ -1336,7 +1339,7 @@ void dpnp_rng_vonmises_large_kappa_c(void* result, const _DataType mu, const _Da
13361339
template <typename _DataType>
13371340
void dpnp_rng_vonmises_small_kappa_c(void* result, const _DataType mu, const _DataType kappa, const size_t size)
13381341
{
1339-
if (!size)
1342+
if (!size || !result)
13401343
{
13411344
return;
13421345
}
@@ -1391,20 +1394,22 @@ void dpnp_rng_vonmises_small_kappa_c(void* result, const _DataType mu, const _Da
13911394
dpnp_memory_free_c(Uvec);
13921395

13931396
mkl_rng::uniform<_DataType> uniform_distribution(d_zero, d_one);
1394-
auto event_out = mkl_rng::generate(uniform_distribution, DPNP_RNG_ENGINE, size, Vvec);
1395-
event_out.wait();
1397+
auto uniform_distr_event = mkl_rng::generate(uniform_distribution, DPNP_RNG_ENGINE, size, Vvec);
13961398

1397-
// TODO
1398-
// kernel
1399-
for (size_t i = 0; i < size; i++)
1400-
{
1401-
double mod, resi;
1402-
1403-
resi = (Vvec[i] < 0.5) ? mu - result1[i] : mu + result1[i];
1404-
mod = fabs(resi);
1405-
mod = (fmod(mod + M_PI, 2 * M_PI) - M_PI);
1406-
result1[i] = (resi < 0) ? -mod : mod;
1407-
}
1399+
cl::sycl::range<1> gws(size);
1400+
auto paral_kernel_acceptance = [&](cl::sycl::handler& cgh) {
1401+
cgh.depends_on({uniform_distr_event});
1402+
cgh.parallel_for(gws, [=](cl::sycl::id<1> global_id) {
1403+
size_t i = global_id[0];
1404+
double mod, resi;
1405+
resi = (Vvec[i] < 0.5) ? mu - result1[i] : mu + result1[i];
1406+
mod = cl::sycl::fabs(resi);
1407+
mod = (cl::sycl::fmod(mod + M_PI, 2 * M_PI) - M_PI);
1408+
result1[i] = (resi < 0) ? -mod : mod;
1409+
});
1410+
};
1411+
auto acceptance_event = DPNP_QUEUE.submit(paral_kernel_acceptance);
1412+
acceptance_event.wait();
14081413

14091414
dpnp_memory_free_c(Vvec);
14101415
return;
@@ -1423,7 +1428,6 @@ void dpnp_rng_vonmises_c(void* result, const _DataType mu, const _DataType kappa
14231428
dpnp_rng_vonmises_large_kappa_c<_DataType>(result, mu, kappa, size);
14241429
else
14251430
dpnp_rng_vonmises_small_kappa_c<_DataType>(result, mu, kappa, size);
1426-
// TODO case when kappa < kappa < 1e-8 (very small)
14271431
}
14281432

14291433
template <typename _KernelNameSpecialization>

0 commit comments

Comments
 (0)