diff --git a/include/htool/hmatrix/hmatrix.hpp b/include/htool/hmatrix/hmatrix.hpp index 2d1d8b65..f7ba0566 100644 --- a/include/htool/hmatrix/hmatrix.hpp +++ b/include/htool/hmatrix/hmatrix.hpp @@ -209,10 +209,14 @@ class HMatrix : public TreeNode> ptr) { this->m_tree_data->m_admissibility_condition = ptr; } void set_minimal_target_depth(unsigned int minimal_target_depth) { this->m_tree_data->m_minimal_target_depth = minimal_target_depth; } void set_minimal_source_depth(unsigned int minimal_source_depth) { this->m_tree_data->m_minimal_source_depth = minimal_source_depth; } + void set_block_tree_consistency(bool consistency) { + this->m_tree_data->m_is_block_tree_consistent = consistency; + } // HMatrix Tree setters char get_symmetry_for_leaves() const { return m_symmetry_type_for_leaves; } char get_UPLO_for_leaves() const { return m_UPLO_for_leaves; } + bool is_block_tree_consistent() const { return this->m_tree_data->m_is_block_tree_consistent; } // Data computation void compute_dense_data(const VirtualGenerator &generator) { diff --git a/include/htool/hmatrix/hmatrix_tree_data.hpp b/include/htool/hmatrix/hmatrix_tree_data.hpp index 5f041818..0624736e 100644 --- a/include/htool/hmatrix/hmatrix_tree_data.hpp +++ b/include/htool/hmatrix/hmatrix_tree_data.hpp @@ -22,6 +22,7 @@ struct HMatrixTreeData { unsigned int m_minimal_target_depth{0}; bool m_delay_dense_computation{false}; int m_reqrank{-1}; + bool m_is_block_tree_consistent{true}; // Information mutable std::map m_information; diff --git a/include/htool/hmatrix/linalg/factorization.hpp b/include/htool/hmatrix/linalg/factorization.hpp index e6cd9c9a..3c4caecc 100644 --- a/include/htool/hmatrix/linalg/factorization.hpp +++ b/include/htool/hmatrix/linalg/factorization.hpp @@ -16,6 +16,10 @@ namespace htool { template > void lu_factorization(HMatrix &hmatrix) { + if (!hmatrix.is_block_tree_consistent()) { + htool::Logger::get_instance().log(LogLevel::ERROR, "lu_factorization is only implemented for consistent block tree."); // LCOV_EXCL_LINE + } + if (hmatrix.is_hierarchical()) { bool block_tree_not_consistent = (hmatrix.get_target_cluster().get_rank() < 0 || hmatrix.get_source_cluster().get_rank() < 0); @@ -83,6 +87,10 @@ void internal_lu_solve(char trans, const HMatrix> void cholesky_factorization(char UPLO, HMatrix &hmatrix) { + if (!hmatrix.is_block_tree_consistent()) { + htool::Logger::get_instance().log(LogLevel::ERROR, "cholesky_factorization is only implemented for consistent block tree."); // LCOV_EXCL_LINE + } + if (hmatrix.is_hierarchical()) { bool block_tree_not_consistent = (hmatrix.get_target_cluster().get_rank() < 0 || hmatrix.get_source_cluster().get_rank() < 0); diff --git a/include/htool/hmatrix/tree_builder/tree_builder.hpp b/include/htool/hmatrix/tree_builder/tree_builder.hpp index 3a8a796b..1eeba7d2 100644 --- a/include/htool/hmatrix/tree_builder/tree_builder.hpp +++ b/include/htool/hmatrix/tree_builder/tree_builder.hpp @@ -57,6 +57,7 @@ class HMatrixTreeBuilder { std::shared_ptr> m_low_rank_generator; std::shared_ptr> m_admissibility_condition; std::shared_ptr> m_dense_blocks_generator; + bool m_is_block_tree_consistent{true}; // Internal methods void build_block_tree(HMatrixType *current_hmatrix) const; @@ -162,6 +163,7 @@ class HMatrixTreeBuilder { void set_minimal_source_depth(int minimal_source_depth) { m_minsourcedepth = minimal_source_depth; } void set_minimal_target_depth(int minimal_target_depth) { m_mintargetdepth = minimal_target_depth; } void set_dense_blocks_generator(std::shared_ptr> dense_blocks_generator) { m_dense_blocks_generator = dense_blocks_generator; } + void set_block_tree_consistency(bool consistency) { m_is_block_tree_consistent = consistency; } // Getters char get_symmetry() const { return m_symmetry_type; } @@ -180,6 +182,7 @@ HMatrix HMatrixTreeBuilder::build_block_ build_block_tree(hmatrix_child); } } + } else if (!m_is_block_tree_consistent && (source_cluster.get_size() > target_cluster.get_size())) { + HMatrixType *hmatrix_child = nullptr; + for (const auto &source_child : source_children) { + if ((is_target_cluster_in_target_partition(target_cluster) || target_cluster.get_rank() < 0) && !is_removed_by_symmetry(target_cluster, *source_child)) { + hmatrix_child = current_hmatrix->add_child(&target_cluster, source_child.get()); + set_hmatrix_symmetry(*hmatrix_child); + build_block_tree(hmatrix_child); + } + } + } else if (!m_is_block_tree_consistent && (target_cluster.get_size() > source_cluster.get_size())) { + HMatrixType *hmatrix_child = nullptr; + for (const auto &target_child : target_children) { + if ((is_target_cluster_in_target_partition(*target_child) || target_child->get_rank() < 0) && !is_removed_by_symmetry(*target_child, source_cluster)) { + hmatrix_child = current_hmatrix->add_child(target_child.get(), &source_cluster); + set_hmatrix_symmetry(*hmatrix_child); + build_block_tree(hmatrix_child); + } + } } else { HMatrixType *hmatrix_child = nullptr; for (const auto &target_child : target_children) { diff --git a/tests/functional_tests/hmatrix/hmatrix_build/test_hmatrix_build_complex_double.cpp b/tests/functional_tests/hmatrix/hmatrix_build/test_hmatrix_build_complex_double.cpp index 0bcfde1e..7f14656d 100644 --- a/tests/functional_tests/hmatrix/hmatrix_build/test_hmatrix_build_complex_double.cpp +++ b/tests/functional_tests/hmatrix/hmatrix_build/test_hmatrix_build_complex_double.cpp @@ -19,14 +19,16 @@ int main(int argc, char *argv[]) { for (auto use_local_cluster : {true, false}) { for (auto epsilon : {1e-14, 1e-6}) { for (auto use_dense_block_generator : {true, false}) { - std::cout << nr << " " << nc << " " << use_local_cluster << " " << epsilon << " " << use_dense_block_generator << "\n"; + for (auto block_tree_consistency : {true, false}) { + std::cout << nr << " " << nc << " " << use_local_cluster << " " << epsilon << " " << use_dense_block_generator << " " << block_tree_consistency << "\n"; - is_error = is_error || test_hmatrix_build, GeneratorTestComplexSymmetric>(nr, nc, use_local_cluster, 'N', 'N', epsilon, use_dense_block_generator); - if (nr == nc) { - for (auto UPLO : {'U', 'L'}) { - is_error = is_error || test_hmatrix_build, GeneratorTestComplexSymmetric>(nr, nr, use_local_cluster, 'S', UPLO, epsilon, use_dense_block_generator); + is_error = is_error || test_hmatrix_build, GeneratorTestComplexSymmetric>(nr, nc, use_local_cluster, 'N', 'N', epsilon, use_dense_block_generator, block_tree_consistency); + if (nr == nc) { + for (auto UPLO : {'U', 'L'}) { + is_error = is_error || test_hmatrix_build, GeneratorTestComplexSymmetric>(nr, nr, use_local_cluster, 'S', UPLO, epsilon, use_dense_block_generator, block_tree_consistency); - is_error = is_error || test_hmatrix_build, GeneratorTestComplexHermitian>(nr, nr, use_local_cluster, 'H', UPLO, epsilon, use_dense_block_generator); + is_error = is_error || test_hmatrix_build, GeneratorTestComplexHermitian>(nr, nr, use_local_cluster, 'H', UPLO, epsilon, use_dense_block_generator, block_tree_consistency); + } } } } diff --git a/tests/functional_tests/hmatrix/hmatrix_build/test_hmatrix_build_double.cpp b/tests/functional_tests/hmatrix/hmatrix_build/test_hmatrix_build_double.cpp index 5bce3740..142f7b4d 100644 --- a/tests/functional_tests/hmatrix/hmatrix_build/test_hmatrix_build_double.cpp +++ b/tests/functional_tests/hmatrix/hmatrix_build/test_hmatrix_build_double.cpp @@ -18,14 +18,16 @@ int main(int argc, char *argv[]) { for (auto use_local_cluster : {true, false}) { for (auto epsilon : {1e-14, 1e-6}) { for (auto use_dense_block_generator : {true, false}) { - std::cout << nr << " " << nc << " " << use_local_cluster << " " << epsilon << " " << use_dense_block_generator << "\n"; + for (auto block_tree_consistency : {true, false}) { + std::cout << nr << " " << nc << " " << use_local_cluster << " " << epsilon << " " << use_dense_block_generator << " " << block_tree_consistency << "\n"; - is_error = is_error || test_hmatrix_build(nr, nc, use_local_cluster, 'N', 'N', epsilon, use_dense_block_generator); + is_error = is_error || test_hmatrix_build(nr, nc, use_local_cluster, 'N', 'N', epsilon, use_dense_block_generator, block_tree_consistency); - if (nr == nc) { - for (auto UPLO : {'U', 'L'}) { - std::cout << UPLO << "\n"; - is_error = is_error || test_hmatrix_build(nr, nc, use_local_cluster, 'S', UPLO, epsilon, use_dense_block_generator); + if (nr == nc) { + for (auto UPLO : {'U', 'L'}) { + std::cout << UPLO << "\n"; + is_error = is_error || test_hmatrix_build(nr, nc, use_local_cluster, 'S', UPLO, epsilon, use_dense_block_generator, block_tree_consistency); + } } } } diff --git a/tests/functional_tests/hmatrix/test_hmatrix_build.hpp b/tests/functional_tests/hmatrix/test_hmatrix_build.hpp index 59c81ebb..bcb78e0a 100644 --- a/tests/functional_tests/hmatrix/test_hmatrix_build.hpp +++ b/tests/functional_tests/hmatrix/test_hmatrix_build.hpp @@ -28,7 +28,7 @@ using namespace std; using namespace htool; template -bool test_hmatrix_build(int nr, int nc, bool use_local_cluster, char Symmetry, char UPLO, htool::underlying_type epsilon, bool use_dense_blocks_generator) { +bool test_hmatrix_build(int nr, int nc, bool use_local_cluster, char Symmetry, char UPLO, htool::underlying_type epsilon, bool use_dense_blocks_generator, bool block_tree_consistency) { // Get the number of processes int sizeWorld; @@ -87,6 +87,7 @@ bool test_hmatrix_build(int nr, int nc, bool use_local_cluster, char Symmetry, c } else { hmatrix_tree_builder = std::make_unique>>(*target_root_cluster, *source_root_cluster, epsilon, eta, Symmetry, UPLO, -1, rankWorld, rankWorld); } + hmatrix_tree_builder->set_block_tree_consistency(block_tree_consistency); std::shared_ptr> dense_blocks_generator; if (use_dense_blocks_generator) {