Skip to content

Commit c1eefa0

Browse files
authored
Merge pull request #524 from ValeevGroup/evaleev/fix/linalg-tests
fix `linalg` tests
2 parents 152dad4 + f881f72 commit c1eefa0

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

tests/linalg.cpp

+28-21
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <tiledarray.h>
22
#include <random>
33
#include "TiledArray/config.h"
4-
//#include "range_fixture.h"
54
#include "unit_test_config.h"
65

76
#include "TiledArray/math/linalg/non-distributed/cholesky.h"
@@ -469,26 +468,34 @@ BOOST_AUTO_TEST_CASE(heig_same_tiling) {
469468
return this->make_ta_reference(t, range);
470469
});
471470

472-
auto [evals, evecs] = non_dist::heig(ref_ta);
471+
auto [evals, evecs] = heig(ref_ta);
473472
auto [evals_non_dist, evecs_non_dist] = non_dist::heig(ref_ta);
474-
// auto evals = heig( ref_ta );
475473

476474
BOOST_CHECK(evecs.trange() == ref_ta.trange());
477475

478-
// check eigenvectors against non_dist only, for now ...
479-
decltype(evecs) evecs_error;
480-
evecs_error("i,j") = evecs_non_dist("i,j") - evecs("i,j");
481-
// TODO need to fix phases of the eigenvectors to be able to compare ...
482-
// BOOST_CHECK_SMALL(evecs_error("i,j").norm().get(),
483-
// N * N * std::numeric_limits<double>::epsilon());
484-
485476
// Check eigenvalue correctness
486477
double tol = N * N * std::numeric_limits<double>::epsilon();
487478
for (int64_t i = 0; i < N; ++i) {
488479
BOOST_CHECK_SMALL(std::abs(evals[i] - exact_evals[i]), tol);
489480
BOOST_CHECK_SMALL(std::abs(evals_non_dist[i] - exact_evals[i]), tol);
490481
}
491482

483+
// check eigenvectors by reconstruction
484+
auto reconstruction_check = [&](const auto& s, const auto& U,
485+
const auto str) {
486+
using Array = TA::TArray<double>;
487+
auto S =
488+
TA::diagonal_array<Array>(U.world(), U.trange(), s.begin(), s.end());
489+
Array err;
490+
err("i,j") = U("i,k") * S("k,l") * U("j,l").conj() - ref_ta("i,j");
491+
auto err_l2 = TA::norm2(err);
492+
const double epsilon = N * N * std::numeric_limits<double>::epsilon();
493+
BOOST_CHECK(err_l2 < epsilon);
494+
// std::cout << str << " ||U s U† - A||_2 = " << err_l2 << std::endl;
495+
};
496+
reconstruction_check(evals, evecs, "heig");
497+
reconstruction_check(evals_non_dist, evecs_non_dist, "non_dist::heig");
498+
492499
GlobalFixture::world->gop.fence();
493500
}
494501

@@ -576,7 +583,7 @@ BOOST_AUTO_TEST_CASE(cholesky) {
576583
return this->make_ta_reference(t, range);
577584
});
578585

579-
auto L = non_dist::cholesky(A);
586+
auto L = TiledArray::cholesky(A);
580587

581588
BOOST_CHECK(L.trange() == A.trange());
582589

@@ -729,7 +736,7 @@ BOOST_AUTO_TEST_CASE(cholesky_lsolve) {
729736
});
730737

731738
// Should produce X = L**H
732-
auto [L, X] = non_dist::cholesky_lsolve(TA::NoTranspose, A, A);
739+
auto [L, X] = TiledArray::cholesky_lsolve(TA::NoTranspose, A, A);
733740
BOOST_CHECK(X.trange() == A.trange());
734741
BOOST_CHECK(L.trange() == A.trange());
735742

@@ -797,7 +804,7 @@ BOOST_AUTO_TEST_CASE(lu_solve) {
797804
return this->make_ta_reference(t, range);
798805
});
799806

800-
auto iden = non_dist::lu_solve(ref_ta, ref_ta);
807+
auto iden = TiledArray::lu_solve(ref_ta, ref_ta);
801808

802809
BOOST_CHECK(iden.trange() == ref_ta.trange());
803810

@@ -834,7 +841,7 @@ BOOST_AUTO_TEST_CASE(lu_inv) {
834841

835842
TA::TArray<double> iden(*GlobalFixture::world, trange);
836843

837-
auto Ainv = non_dist::lu_inv(ref_ta);
844+
auto Ainv = TiledArray::lu_inv(ref_ta);
838845
iden("i,j") = Ainv("i,k") * ref_ta("k,j");
839846

840847
BOOST_CHECK(iden.trange() == ref_ta.trange());
@@ -871,7 +878,7 @@ BOOST_AUTO_TEST_CASE(svd_values_only) {
871878
return this->make_ta_reference(t, range);
872879
});
873880

874-
auto S = non_dist::svd<TA::SVD::ValuesOnly>(ref_ta, trange, trange);
881+
auto S = svd<TA::SVD::ValuesOnly>(ref_ta, trange, trange);
875882

876883
std::vector exact_singular_values = exact_evals;
877884
std::sort(exact_singular_values.begin(), exact_singular_values.end(),
@@ -895,7 +902,7 @@ BOOST_AUTO_TEST_CASE(svd_leftvectors) {
895902
return this->make_ta_reference(t, range);
896903
});
897904

898-
auto [S, U] = non_dist::svd<TA::SVD::LeftVectors>(ref_ta, trange, trange);
905+
auto [S, U] = svd<TA::SVD::LeftVectors>(ref_ta, trange, trange);
899906

900907
std::vector exact_singular_values = exact_evals;
901908
std::sort(exact_singular_values.begin(), exact_singular_values.end(),
@@ -919,7 +926,7 @@ BOOST_AUTO_TEST_CASE(svd_rightvectors) {
919926
return this->make_ta_reference(t, range);
920927
});
921928

922-
auto [S, VT] = non_dist::svd<TA::SVD::RightVectors>(ref_ta, trange, trange);
929+
auto [S, VT] = svd<TA::SVD::RightVectors>(ref_ta, trange, trange);
923930

924931
std::vector exact_singular_values = exact_evals;
925932
std::sort(exact_singular_values.begin(), exact_singular_values.end(),
@@ -943,7 +950,7 @@ BOOST_AUTO_TEST_CASE(svd_allvectors) {
943950
return this->make_ta_reference(t, range);
944951
});
945952

946-
auto [S, U, VT] = non_dist::svd<TA::SVD::AllVectors>(ref_ta, trange, trange);
953+
auto [S, U, VT] = svd<TA::SVD::AllVectors>(ref_ta, trange, trange);
947954

948955
std::vector exact_singular_values = exact_evals;
949956
std::sort(exact_singular_values.begin(), exact_singular_values.end(),
@@ -985,7 +992,7 @@ void householder_qr_test(const ArrayT& A, double tol) {
985992
: non_dist::householder_qr<false>(A);
986993
#else
987994
static_assert(not use_scalapack);
988-
auto [Q, R] = non_dist::householder_qr<false>(A);
995+
auto [Q, R] = householder_qr<false>(A);
989996
#endif
990997

991998
// Check reconstruction error
@@ -1046,7 +1053,7 @@ template <typename ArrayT>
10461053
void cholesky_qr_q_only_test(const ArrayT& A, double tol) {
10471054
using value_type = typename ArrayT::element_type;
10481055

1049-
auto Q = TiledArray::math::linalg::cholesky_qr<true>(A);
1056+
auto Q = TiledArray::cholesky_qr<true>(A);
10501057

10511058
// Make sure the Q is orthogonal at least
10521059
TA::TArray<double> Iden;
@@ -1059,7 +1066,7 @@ void cholesky_qr_q_only_test(const ArrayT& A, double tol) {
10591066

10601067
template <typename ArrayT>
10611068
void cholesky_qr_test(const ArrayT& A, double tol) {
1062-
auto [Q, R] = TiledArray::math::linalg::cholesky_qr<false>(A);
1069+
auto [Q, R] = TiledArray::cholesky_qr<false>(A);
10631070

10641071
// Check reconstruction error
10651072
TA::TArray<double> QR_ERROR;

0 commit comments

Comments
 (0)