Skip to content

Commit 7e874d8

Browse files
committed
Add sparse matrix usage to example
1 parent c267492 commit 7e874d8

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

scripts/build_run.sh

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ mpirun -n 3 ./build/src/example3
1616
mpirun -n 3 ./build/src/example4
1717
mpirun -n 3 ./build/src/example5
1818
mpirun -n 3 ./build/src/example6
19+
mpirun -n 3 ./build/src/example7

src/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,8 @@ add_executable(example6 example6.cpp)
4141

4242
target_compile_definitions(example6 INTERFACE DR_FORMAT)
4343
target_link_libraries(example6 DR::mpi fmt::fmt)
44+
45+
add_executable(example7 example7.cpp)
46+
47+
target_compile_definitions(example7 INTERFACE DR_FORMAT)
48+
target_link_libraries(example7 DR::mpi fmt::fmt)

src/example7.cpp

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// SPDX-FileCopyrightText: Intel Corporation
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include <dr/mp.hpp>
6+
#include <fmt/core.h>
7+
8+
9+
/* Sparse band matrix vector multiplication */
10+
int main() {
11+
dr::mp::init(sycl::default_selector_v);
12+
using I = long;
13+
using V = double;
14+
dr::views::csr_matrix_view<V, I> local_data;
15+
auto root = 0;
16+
auto n = 10;
17+
auto up = 1; // number of diagonals above main diagonal
18+
auto down = up; // number of diagonals below main diagonal
19+
if (root == dr::mp::rank()) {
20+
local_data = dr::generate_band_csr<V, I>(n, up, down);
21+
}
22+
23+
dr::mp::distributed_sparse_matrix<
24+
V, I, dr::mp::MpiBackend,
25+
dr::mp::csr_eq_distribution<V, I, dr::mp::MpiBackend>>
26+
matrix(local_data, root);
27+
28+
std::vector<double> b;
29+
b.reserve(matrix.shape().second);
30+
std::vector<double> res(matrix.shape().first);
31+
for (auto i = 0; i < matrix.shape().second; i++) {
32+
b.push_back(i);
33+
}
34+
35+
dr::mp::broadcasted_vector<double> broadcasted_b;
36+
broadcasted_b.broadcast_data(matrix.shape().second, 0, b,
37+
dr::mp::default_comm());
38+
39+
gemv(root, res, matrix, broadcasted_b);
40+
41+
if (root == dr::mp::rank()) {
42+
fmt::print("Band matrix {} x {} with bandwitch {}\n", n, n, up * 2);
43+
fmt::print("Input: ");
44+
for (auto x: b) {
45+
fmt::print("{} ", x);
46+
}
47+
fmt::print("\n");
48+
fmt::print("Matrix vector multiplication res: ");
49+
for (auto x: res) {
50+
fmt::print("{} ", x);
51+
}
52+
fmt::print("\n");
53+
54+
}
55+
56+
dr::mp::finalize();
57+
58+
return 0;
59+
}

0 commit comments

Comments
 (0)