Skip to content

Commit 4401ebe

Browse files
committed
handle empty matrices in broadcast
1 parent da3174e commit 4401ebe

File tree

1 file changed

+46
-42
lines changed

1 file changed

+46
-42
lines changed

src/broadcast.cpp

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,60 @@
1-
#include <iostream>
21
#include <omp.h>
3-
#include <vector>
2+
3+
#include <iostream>
44
#include <stdexcept>
5+
#include <vector>
56

67
// Function to broadcast two matrices
7-
void broadcast(const std::vector<std::vector<int>>& A, std::vector<std::vector<int>>& B) {
8-
size_t rowsA = A.size();
9-
size_t colsA = A[0].size();
10-
size_t rowsB = B.size();
11-
size_t colsB = B[0].size();
12-
13-
if (rowsA != rowsB && rowsB != 1) {
14-
throw std::invalid_argument("Incompatible dimensions for broadcasting");
15-
}
16-
if (colsA != colsB && colsB != 1) {
17-
throw std::invalid_argument("Incompatible dimensions for broadcasting");
18-
}
8+
void broadcast(const std::vector<std::vector<int>>& A,
9+
std::vector<std::vector<int>>& B) {
10+
size_t rowsA = A.size();
11+
size_t colsA = A[0].size();
12+
size_t rowsB = B.size();
13+
size_t colsB = B[0].size();
1914

20-
if (rowsB == 1) {
21-
B.resize(rowsA, B[0]);
22-
}
23-
if (colsB == 1) {
24-
for (auto& row : B) {
25-
row.resize(colsA, row[0]);
26-
}
15+
if (rowsA == 0 || colsA == 0 || rowsB == 0 || colsB == 0) {
16+
throw std::invalid_argument("Empty matrix cannot be broadcasted");
17+
}
18+
if (rowsA != rowsB && rowsB != 1) {
19+
throw std::invalid_argument("Incompatible dimensions for broadcasting");
20+
}
21+
if (colsA != colsB && colsB != 1) {
22+
throw std::invalid_argument("Incompatible dimensions for broadcasting");
23+
}
24+
25+
if (rowsB == 1) {
26+
B.resize(rowsA, B[0]);
27+
}
28+
if (colsB == 1) {
29+
for (auto& row : B) {
30+
row.resize(colsA, row[0]);
2731
}
32+
}
2833

29-
30-
#pragma omp parallel for
31-
for (size_t i = 0; i < rowsA; i++) {
32-
#pragma omp parallel for
33-
for (size_t j = 0; j < colsA; j++) {
34-
B[i][j] = A[i][j]+B[i][j];
35-
}
34+
#pragma omp parallel for
35+
for (size_t i = 0; i < rowsA; i++) {
36+
#pragma omp parallel for
37+
for (size_t j = 0; j < colsA; j++) {
38+
B[i][j] = A[i][j] + B[i][j];
3639
}
40+
}
3741
}
3842

3943
int main() {
40-
std::vector<std::vector<int>> A = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
41-
std::vector<std::vector<int>> B = {{1,2,3}}; // B has only one row
42-
43-
try {
44-
broadcast(A, B);
45-
for (const auto& row : B) {
46-
for (const auto& elem : row) {
47-
std::cout << elem << " ";
48-
}
49-
std::cout << std::endl;
50-
}
51-
} catch (const std::invalid_argument& e) {
52-
std::cerr << "Error: " << e.what() << std::endl;
44+
std::vector<std::vector<int>> A = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
45+
std::vector<std::vector<int>> B = {{1, 2, 3}}; // B has only one row
46+
47+
try {
48+
broadcast(A, B);
49+
for (const auto& row : B) {
50+
for (const auto& elem : row) {
51+
std::cout << elem << " ";
52+
}
53+
std::cout << std::endl;
5354
}
55+
} catch (const std::invalid_argument& e) {
56+
std::cerr << "Error: " << e.what() << std::endl;
57+
}
5458

55-
return 0;
59+
return 0;
5660
}

0 commit comments

Comments
 (0)