@@ -40,37 +40,34 @@ template <typename funcPtrT,
40
40
class Dispatch3DTableBuilder
41
41
{
42
42
private:
43
- template <typename E, typename T>
43
+ template <typename E, typename T, typename ... Methods >
44
44
const std::vector<funcPtrT> row_per_method () const
45
45
{
46
46
std::vector<funcPtrT> per_method = {
47
- factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}
48
- .get (),
49
- factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}
50
- .get (),
47
+ factory<funcPtrT, E, T, Methods>{}.get ()...,
51
48
};
52
49
assert (per_method.size () == _no_of_methods);
53
50
return per_method;
54
51
}
55
52
56
- template <typename E>
53
+ template <typename E, typename ... Methods >
57
54
auto table_per_type_and_method () const
58
55
{
59
56
std::vector<std::vector<funcPtrT>> table_by_type = {
60
- row_per_method<E, bool >(),
61
- row_per_method<E, int8_t >(),
62
- row_per_method<E, uint8_t >(),
63
- row_per_method<E, int16_t >(),
64
- row_per_method<E, uint16_t >(),
65
- row_per_method<E, int32_t >(),
66
- row_per_method<E, uint32_t >(),
67
- row_per_method<E, int64_t >(),
68
- row_per_method<E, uint64_t >(),
69
- row_per_method<E, sycl::half>(),
70
- row_per_method<E, float >(),
71
- row_per_method<E, double >(),
72
- row_per_method<E, std::complex<float >>(),
73
- row_per_method<E, std::complex<double >>()};
57
+ row_per_method<E, bool , Methods... >(),
58
+ row_per_method<E, int8_t , Methods... >(),
59
+ row_per_method<E, uint8_t , Methods... >(),
60
+ row_per_method<E, int16_t , Methods... >(),
61
+ row_per_method<E, uint16_t , Methods... >(),
62
+ row_per_method<E, int32_t , Methods... >(),
63
+ row_per_method<E, uint32_t , Methods... >(),
64
+ row_per_method<E, int64_t , Methods... >(),
65
+ row_per_method<E, uint64_t , Methods... >(),
66
+ row_per_method<E, sycl::half, Methods... >(),
67
+ row_per_method<E, float , Methods... >(),
68
+ row_per_method<E, double , Methods... >(),
69
+ row_per_method<E, std::complex<float >, Methods... >(),
70
+ row_per_method<E, std::complex<double >, Methods... >()};
74
71
assert (table_by_type.size () == _no_of_types);
75
72
return table_by_type;
76
73
}
@@ -79,16 +76,15 @@ class Dispatch3DTableBuilder
79
76
Dispatch3DTableBuilder () = default ;
80
77
~Dispatch3DTableBuilder () = default ;
81
78
82
- template <std::uint8_t ... VecSizes>
79
+ template <typename ... Methods, std::uint8_t ... VecSizes>
83
80
void populate (funcPtrT table[][_no_of_types][_no_of_methods],
84
81
std::integer_sequence<std::uint8_t , VecSizes...>) const
85
82
{
86
83
const auto map_by_engine = {
87
- table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>>()...,
88
- table_per_type_and_method<
89
- mkl_rng_dev::philox4x32x10<VecSizes>>()...,
90
- table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>>()...,
91
- table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>>()...};
84
+ table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>, Methods...>()...,
85
+ table_per_type_and_method<mkl_rng_dev::philox4x32x10<VecSizes>, Methods...>()...,
86
+ table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>, Methods...>()...,
87
+ table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>, Methods...>()...};
92
88
assert (map_by_engine.size () == _no_of_engines);
93
89
94
90
std::uint16_t engine_id = 0 ;
0 commit comments