Skip to content

Commit 13e2a14

Browse files
committed
wrap flat matrices
1 parent 87a66e6 commit 13e2a14

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

ortools/util/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ cc_library(
358358
name = "vector_or_function",
359359
hdrs = ["vector_or_function.h"],
360360
deps = [
361+
":flat_matrix",
361362
"//ortools/base",
362363
],
363364
)
@@ -405,6 +406,7 @@ cc_library(
405406
cc_library(
406407
name = "random_engine",
407408
hdrs = ["random_engine.h"],
409+
deps = [],
408410
)
409411

410412
cc_library(

ortools/util/vector_or_function.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <vector>
1919

2020
#include "ortools/base/logging.h"
21+
#include "ortools/util/flat_matrix.h"
2122

2223
namespace operations_research {
2324

@@ -90,6 +91,20 @@ class MatrixOrFunction<ScalarType, std::vector<std::vector<ScalarType>>,
9091
std::vector<std::vector<ScalarType>> matrix_;
9192
};
9293

94+
// Specialization for FlatMatrix<>, which is faster than vector<vector<>>.
95+
template <typename ScalarType, bool square>
96+
class MatrixOrFunction<ScalarType, FlatMatrix<ScalarType>, square> {
97+
public:
98+
explicit MatrixOrFunction(FlatMatrix<ScalarType> matrix)
99+
: matrix_(std::move(matrix)) {}
100+
void Reset(FlatMatrix<ScalarType> matrix) { matrix_ = std::move(matrix); }
101+
ScalarType operator()(int i, int j) const { return matrix_[i][j]; }
102+
bool Check() const { return true; }
103+
104+
private:
105+
FlatMatrix<ScalarType> matrix_;
106+
};
107+
93108
} // namespace operations_research
94109

95110
#endif // OR_TOOLS_UTIL_VECTOR_OR_FUNCTION_H_

0 commit comments

Comments
 (0)