Skip to content

Commit

Permalink
rework on ddm solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreMarchand20 committed Jul 26, 2024
1 parent fb27f52 commit 93821de
Show file tree
Hide file tree
Showing 13 changed files with 1,140 additions and 413 deletions.
4 changes: 3 additions & 1 deletion include/htool/hmatrix/hmatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ class HMatrix : public TreeNode<HMatrix<CoefficientPrecision, CoordinatePrecisio
HMatrix(const HMatrix &parent, const Cluster<CoordinatePrecision> *target_cluster, const Cluster<CoordinatePrecision> *source_cluster) : TreeNode<HMatrix, HMatrixTreeData<CoefficientPrecision, CoordinatePrecision>>(parent), m_target_cluster(target_cluster), m_source_cluster(source_cluster) {}

HMatrix(const HMatrix &rhs) : TreeNode<HMatrix<CoefficientPrecision, CoordinatePrecision>, HMatrixTreeData<CoefficientPrecision, CoordinatePrecision>>(rhs), m_target_cluster(rhs.m_target_cluster), m_source_cluster(rhs.m_source_cluster), m_symmetry(rhs.m_symmetry), m_UPLO(rhs.m_UPLO), m_leaves(), m_leaves_for_symmetry(), m_symmetry_type_for_leaves(), m_storage_type(rhs.m_storage_type) {
Logger::get_instance().log(LogLevel::INFO, "Deep copy of HMatrix");
if (m_target_cluster->is_root() or is_cluster_on_partition(*m_target_cluster)) {
Logger::get_instance().log(LogLevel::INFO, "Deep copy of HMatrix");
}
this->m_depth = rhs.m_depth;
this->m_is_root = rhs.m_is_root;
this->m_tree_data = std::make_shared<HMatrixTreeData<CoefficientPrecision, CoordinatePrecision>>(*rhs.m_tree_data);
Expand Down
2 changes: 1 addition & 1 deletion include/htool/hmatrix/tree_builder/tree_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace htool {

template <typename CoefficientPrecision, typename CoordinatePrecision>
template <typename CoefficientPrecision, typename CoordinatePrecision=underlying_type<CoefficientPrecision>>
class HMatrixTreeBuilder {
private:
class ZeroGenerator : public VirtualGenerator<CoefficientPrecision> {
Expand Down
464 changes: 184 additions & 280 deletions include/htool/solvers/ddm.hpp

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions include/htool/solvers/interfaces/virtual_local_solver.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef HTOOL_WRAPPERS_INTERFACE_HPP
#define HTOOL_WRAPPERS_INTERFACE_HPP
#define HPDDM_NUMBERING 'F'
#define HPDDM_DENSE 1
#define HPDDM_FETI 0
#define HPDDM_BDD 0
#define LAPACKSUB
#define DLAPACK
#define EIGENSOLVER 1
#if defined(__clang__)
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wsign-compare"
# pragma clang diagnostic ignored "-Wshadow"
# pragma clang diagnostic ignored "-Wdouble-promotion"
# pragma clang diagnostic ignored "-Wunused-parameter"
# pragma clang diagnostic ignored "-Wnon-virtual-dtor"
#elif defined(__GNUC__) || defined(__GNUG__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wsign-compare"
# pragma GCC diagnostic ignored "-Wshadow"
# pragma GCC diagnostic ignored "-Wdouble-promotion"
# pragma GCC diagnostic ignored "-Wunused-parameter"
# pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
# pragma GCC diagnostic ignored "-Wuseless-cast"
# pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#endif

#include <HPDDM.hpp>

#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__) || defined(__GNUG__)
# pragma GCC diagnostic pop
#endif

namespace htool {

template <typename CoefficientPrecision>
class VirtualLocalSolver {
public:
virtual void numfact(HPDDM::MatrixCSR<CoefficientPrecision> *const &, bool = false, CoefficientPrecision *const & = nullptr) = 0;

virtual void solve(CoefficientPrecision *const, const unsigned short & = 1) const = 0;
virtual void solve(const CoefficientPrecision *const, CoefficientPrecision *const, const unsigned short & = 1) const = 0;

virtual ~VirtualLocalSolver() {}
};

} // namespace htool
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#ifndef HTOOL_SOLVERS_LOCAL_SOLVERS_HMATRIX_PLUS_OVERLAP_HPP
#define HTOOL_SOLVERS_LOCAL_SOLVERS_HMATRIX_PLUS_OVERLAP_HPP

#include "../../hmatrix/linalg/factorization.hpp"
#include "../interfaces/virtual_local_solver.hpp"

namespace htool {

template <typename CoefficientPrecision, typename CoordinatePrecision>
class LocalHMatrixPlusOverlapSolver : public VirtualLocalSolver<CoefficientPrecision> {
private:
HMatrix<CoefficientPrecision, CoordinatePrecision> &m_local_hmatrix;
Matrix<CoefficientPrecision> &m_B, &m_C, &m_D;
mutable Matrix<CoefficientPrecision> buffer;

public:
LocalHMatrixPlusOverlapSolver(HMatrix<CoefficientPrecision> &local_hmatrix, Matrix<CoefficientPrecision> &B, Matrix<CoefficientPrecision> &C, Matrix<CoefficientPrecision> &D) : m_local_hmatrix(local_hmatrix), m_B(B), m_C(C), m_D(D) {}
void numfact(HPDDM::MatrixCSR<CoefficientPrecision> *const &, bool = false, CoefficientPrecision *const & = nullptr) {
if (m_local_hmatrix.get_symmetry() == 'N') {
lu_factorization(m_local_hmatrix);
if (m_C.nb_rows() > 0) {
triangular_hmatrix_matrix_solve('L', 'L', 'N', 'U', CoefficientPrecision(1), m_local_hmatrix, m_B);
triangular_hmatrix_matrix_solve('R', 'U', 'N', 'N', CoefficientPrecision(1), m_local_hmatrix, m_C);
add_matrix_matrix_product('N', 'N', CoefficientPrecision(-1), m_C, m_B, CoefficientPrecision(1), m_D);
lu_factorization(m_D);
}

} else if (m_local_hmatrix.get_symmetry() == 'S' || m_local_hmatrix.get_symmetry() == 'H') {
cholesky_factorization(m_local_hmatrix.get_UPLO(), m_local_hmatrix);
if (m_C.nb_rows() > 0) {
triangular_hmatrix_matrix_solve('R', m_local_hmatrix.get_UPLO(), is_complex<CoefficientPrecision>() ? 'C' : 'T', 'N', CoefficientPrecision(1), m_local_hmatrix, m_C);
add_matrix_matrix_product('N', is_complex<CoefficientPrecision>() ? 'C' : 'T', CoefficientPrecision(-1), m_C, m_C, CoefficientPrecision(1), m_D);
cholesky_factorization(m_local_hmatrix.get_UPLO(), m_D);
}
}
}
void solve(CoefficientPrecision *const b, const unsigned short &mu = 1) const {
int local_size_wo_overlap = m_local_hmatrix.get_target_cluster().get_size();
int size_overlap = m_C.nb_rows();
int local_size_w_overlap = local_size_wo_overlap + size_overlap;
Matrix<CoefficientPrecision> b1(local_size_wo_overlap, mu), b2(size_overlap, mu);

for (int i = 0; i < mu; i++) {
std::copy_n(b + i * local_size_w_overlap, local_size_wo_overlap, b1.data() + i * local_size_wo_overlap);
std::copy_n(b + i * local_size_w_overlap + local_size_wo_overlap, size_overlap, b2.data() + i * size_overlap);
}

if (m_local_hmatrix.get_symmetry() == 'N') {

triangular_hmatrix_matrix_solve('L', 'L', 'N', 'U', CoefficientPrecision(1), m_local_hmatrix, b1);
if (m_C.nb_rows() > 0) {
add_matrix_matrix_product('N', 'N', CoefficientPrecision(-1), m_C, b1, CoefficientPrecision(1), b2);
triangular_matrix_matrix_solve('L', 'L', 'N', 'U', CoefficientPrecision(1), m_D, b2);

triangular_matrix_matrix_solve('L', 'U', 'N', 'N', CoefficientPrecision(1), m_D, b2);
add_matrix_matrix_product('N', 'N', CoefficientPrecision(-1), m_B, b2, CoefficientPrecision(1), b1);
}
triangular_hmatrix_matrix_solve('L', 'U', 'N', 'N', CoefficientPrecision(1), m_local_hmatrix, b1);

} else if (m_local_hmatrix.get_symmetry() == 'S' || m_local_hmatrix.get_symmetry() == 'H') {
triangular_hmatrix_matrix_solve('L', 'L', 'N', 'N', CoefficientPrecision(1), m_local_hmatrix, b1);

if (m_C.nb_rows() > 0) {
add_matrix_matrix_product('N', 'N', CoefficientPrecision(-1), m_C, b1, CoefficientPrecision(1), b2);
triangular_matrix_matrix_solve('L', 'L', 'N', 'U', CoefficientPrecision(1), m_D, b2);

triangular_matrix_matrix_solve('L', 'L', is_complex<CoefficientPrecision>() ? 'C' : 'T', 'N', CoefficientPrecision(1), m_D, b2);
add_matrix_matrix_product('N', 'N', CoefficientPrecision(-1), m_C, b2, CoefficientPrecision(1), b1);
}
triangular_hmatrix_matrix_solve('L', 'L', is_complex<CoefficientPrecision>() ? 'C' : 'T', 'N', CoefficientPrecision(1), m_local_hmatrix, b1);
}

for (int i = 0; i < mu; i++) {
std::copy_n(b1.data() + i * local_size_wo_overlap, local_size_wo_overlap, b + i * local_size_w_overlap);
std::copy_n(b2.data() + i * size_overlap, size_overlap, b + i * local_size_w_overlap + local_size_wo_overlap);
}
}
void solve(const CoefficientPrecision *const b, CoefficientPrecision *const x, const unsigned short &mu = 1) const {

int local_size_wo_overlap = m_local_hmatrix.get_target_cluster().get_size();
int size_overlap = m_C.nb_rows();
int local_size_w_overlap = local_size_wo_overlap + size_overlap;
Matrix<CoefficientPrecision> b1(local_size_wo_overlap, mu), b2(size_overlap, mu);

for (int i = 0; i < mu; i++) {
std::copy_n(b + i * local_size_w_overlap, local_size_wo_overlap, b1.data() + i * local_size_wo_overlap);
std::copy_n(b + i * local_size_w_overlap + local_size_wo_overlap, size_overlap, b2.data() + i * size_overlap);
}

if (m_local_hmatrix.get_symmetry() == 'N') {

triangular_hmatrix_matrix_solve('L', 'L', 'N', 'U', CoefficientPrecision(1), m_local_hmatrix, b1);

if (m_C.nb_rows() > 0) {
add_matrix_matrix_product('N', 'N', CoefficientPrecision(-1), m_C, b1, CoefficientPrecision(1), b2);
triangular_matrix_matrix_solve('L', 'L', 'N', 'U', CoefficientPrecision(1), m_D, b2);
triangular_matrix_matrix_solve('L', 'U', 'N', 'N', CoefficientPrecision(1), m_D, b2);
add_matrix_matrix_product('N', 'N', CoefficientPrecision(-1), m_B, b2, CoefficientPrecision(1), b1);
}
triangular_hmatrix_matrix_solve('L', 'U', 'N', 'N', CoefficientPrecision(1), m_local_hmatrix, b1);

} else if (m_local_hmatrix.get_symmetry() == 'S' || m_local_hmatrix.get_symmetry() == 'H') {
triangular_hmatrix_matrix_solve('L', 'L', 'N', 'N', CoefficientPrecision(1), m_local_hmatrix, b1);
if (m_C.nb_rows() > 0) {
add_matrix_matrix_product('N', 'N', CoefficientPrecision(-1), m_C, b1, CoefficientPrecision(1), b2);
triangular_matrix_matrix_solve('L', 'L', 'N', 'U', CoefficientPrecision(1), m_D, b2);

triangular_matrix_matrix_solve('L', 'L', is_complex<CoefficientPrecision>() ? 'C' : 'T', 'N', CoefficientPrecision(1), m_D, b2);
add_matrix_matrix_product(is_complex<CoefficientPrecision>() ? 'C' : 'T', 'N', CoefficientPrecision(-1), m_C, b2, CoefficientPrecision(1), b1);
}
triangular_hmatrix_matrix_solve('L', 'L', is_complex<CoefficientPrecision>() ? 'C' : 'T', 'N', CoefficientPrecision(1), m_local_hmatrix, b1);
}
// std::cout << "TEST " << get_max(b1) << " " << get_max(b2) << "\n";
for (int i = 0; i < mu; i++) {
std::copy_n(b1.data() + i * local_size_wo_overlap, local_size_wo_overlap, x + i * local_size_w_overlap);
std::copy_n(b2.data() + i * size_overlap, size_overlap, x + i * local_size_w_overlap + local_size_wo_overlap);
}
}
};
} // namespace htool
#endif
82 changes: 82 additions & 0 deletions include/htool/solvers/local_solvers/local_hmatrix_solvers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#ifndef HTOOL_SOLVERS_LOCAL_SOLVERS_HMATRIX_HPP
#define HTOOL_SOLVERS_LOCAL_SOLVERS_HMATRIX_HPP

#include "../../hmatrix/linalg/factorization.hpp"
#include "../interfaces/virtual_local_solver.hpp"

namespace htool {

template <typename CoefficientPrecision, typename CoordinatePrecision>
class LocalHMatrixSolver : public VirtualLocalSolver<CoefficientPrecision> {
private:
HMatrix<CoefficientPrecision, CoordinatePrecision> &m_local_hmatrix;
bool m_is_using_permutation;
mutable Matrix<CoefficientPrecision> buffer;

public:
LocalHMatrixSolver(HMatrix<CoefficientPrecision> &local_hmatrix, bool is_using_permutation) : m_local_hmatrix(local_hmatrix), m_is_using_permutation(is_using_permutation) {}
void numfact(HPDDM::MatrixCSR<CoefficientPrecision> *const &, bool = false, CoefficientPrecision *const & = nullptr) {
if (m_local_hmatrix.get_symmetry() == 'N') {
lu_factorization(m_local_hmatrix);
} else if (m_local_hmatrix.get_symmetry() == 'S' || m_local_hmatrix.get_symmetry() == 'H') {
cholesky_factorization(m_local_hmatrix.get_UPLO(), m_local_hmatrix);
}
}
void solve(CoefficientPrecision *const b, const unsigned short &mu = 1) const {

if (m_is_using_permutation) {
if (buffer.nb_rows() != m_local_hmatrix.nb_cols() or buffer.nb_cols() != mu) {
buffer.resize(m_local_hmatrix.nb_cols(), mu);
}

auto &source_cluster = m_local_hmatrix.get_source_cluster();
for (int i = 0; i < mu; i++) {
user_to_cluster(source_cluster, b + source_cluster.get_size() * i, buffer.data() + source_cluster.get_size() * i);
}
} else {
buffer.assign(m_local_hmatrix.nb_cols(), mu, b, false);
}

if (m_local_hmatrix.get_symmetry() == 'N') {
lu_solve('N', m_local_hmatrix, buffer);
} else if (m_local_hmatrix.get_symmetry() == 'S' || m_local_hmatrix.get_symmetry() == 'H') {
cholesky_solve(m_local_hmatrix.get_UPLO(), m_local_hmatrix, buffer);
}

if (m_is_using_permutation) {
auto &target_cluster = m_local_hmatrix.get_target_cluster();
for (int i = 0; i < mu; i++) {
cluster_to_user(target_cluster, buffer.data() + target_cluster.get_size() * i, b + target_cluster.get_size() * i);
}
}
}
void solve(const CoefficientPrecision *const b, CoefficientPrecision *const x, const unsigned short &mu = 1) const {
if (buffer.nb_rows() != m_local_hmatrix.nb_cols() or buffer.nb_cols() != mu) {
buffer.resize(m_local_hmatrix.nb_cols(), mu);
}
if (m_is_using_permutation) {
auto &source_cluster = m_local_hmatrix.get_source_cluster();
for (int i = 0; i < mu; i++) {
user_to_cluster(source_cluster, b + source_cluster.get_size() * i, buffer.data() + source_cluster.get_size() * i);
}
} else {
buffer.assign(m_local_hmatrix.nb_cols(), mu, x, false);
std::copy_n(b, mu * m_local_hmatrix.nb_cols(), buffer.data());
}

if (m_local_hmatrix.get_symmetry() == 'N') {
lu_solve('N', m_local_hmatrix, buffer);
} else if (m_local_hmatrix.get_symmetry() == 'S' || m_local_hmatrix.get_symmetry() == 'H') {
cholesky_solve(m_local_hmatrix.get_UPLO(), m_local_hmatrix, buffer);
}

if (m_is_using_permutation) {
auto &target_cluster = m_local_hmatrix.get_target_cluster();
for (int i = 0; i < mu; i++) {
cluster_to_user(target_cluster, buffer.data() + target_cluster.get_size() * i, x + target_cluster.get_size() * i);
}
}
}
};
} // namespace htool
#endif
Loading

0 comments on commit 93821de

Please sign in to comment.