5
5
#include < dr/mp.hpp>
6
6
#include < fmt/core.h>
7
7
8
-
9
8
/* Sparse band matrix vector multiplication */
10
9
int main () {
11
10
dr::mp::init (sycl::default_selector_v);
@@ -14,10 +13,10 @@ int main() {
14
13
dr::views::csr_matrix_view<V, I> local_data;
15
14
auto root = 0 ;
16
15
auto n = 10 ;
17
- auto up = 1 ; // number of diagonals above main diagonal
16
+ auto up = 1 ; // number of diagonals above main diagonal
18
17
auto down = up; // number of diagonals below main diagonal
19
18
if (root == dr::mp::rank ()) {
20
- local_data = dr::generate_band_csr<V, I>(n, up, down);
19
+ local_data = dr::generate_band_csr<V, I>(n, up, down);
21
20
}
22
21
23
22
dr::mp::distributed_sparse_matrix<
@@ -34,23 +33,22 @@ int main() {
34
33
35
34
dr::mp::broadcasted_vector<double > broadcasted_b;
36
35
broadcasted_b.broadcast_data (matrix.shape ().second , 0 , b,
37
- dr::mp::default_comm ());
36
+ dr::mp::default_comm ());
38
37
39
38
gemv (root, res, matrix, broadcasted_b);
40
39
41
40
if (root == dr::mp::rank ()) {
42
41
fmt::print (" Band matrix {} x {} with bandwitch {}\n " , n, n, up * 2 );
43
42
fmt::print (" Input: " );
44
- for (auto x: b) {
45
- fmt::print (" {} " , x);
43
+ for (auto x : b) {
44
+ fmt::print (" {} " , x);
46
45
}
47
46
fmt::print (" \n " );
48
47
fmt::print (" Matrix vector multiplication res: " );
49
- for (auto x: res) {
50
- fmt::print (" {} " , x);
48
+ for (auto x : res) {
49
+ fmt::print (" {} " , x);
51
50
}
52
51
fmt::print (" \n " );
53
-
54
52
}
55
53
56
54
dr::mp::finalize ();
0 commit comments