Skip to content

Updates to CP #499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
028c53d
cleanup [skip ci]
evaleev Mar 10, 2023
e2fd8bd
bump btas tag
kmp5VT Aug 22, 2023
ffcf75b
Merge branch 'master' into kmp5/feature/CP
kmp5VT Dec 18, 2024
bc96de8
Revert btas tag
kmp5VT Dec 18, 2024
34fe780
Bump btas tag
kmp5VT Dec 18, 2024
65f75f1
bump btas tag
kmp5VT Dec 19, 2024
a52b369
Make a new CP ALS which takes the THC format
kmp5VT Dec 24, 2024
0565328
Bump btas tag
kmp5VT Jan 13, 2025
68260d5
Merge branch 'master' into kmp5/feature/CP
kmp5VT Jan 22, 2025
6ac5e35
Create a way to set the CP factor matrices of thc based CP
kmp5VT Jan 23, 2025
2bc3d90
Update THC based solver to work with unsymmetric tensors
kmp5VT Feb 23, 2025
fcdd1a6
Update CP to not force lambda into one of the factors
kmp5VT Feb 23, 2025
8b0465a
Return the correct error
kmp5VT Feb 28, 2025
fbe772a
Merge branch 'master' into kmp5/feature/CP
kmp5VT Mar 13, 2025
62b7d1e
Merge branch 'master' into kmp5/feature/CP
kmp5VT Mar 13, 2025
7adc805
fix btas tag
kmp5VT Mar 13, 2025
c3ef59e
Add a tests for the new CP solver
kmp5VT Mar 13, 2025
51b8cfa
Merge branch 'master' into kmp5/feature/CP
kmp5VT Mar 17, 2025
6558ff8
Make it possible to set factor matrices instead of computing new ones
kmp5VT Mar 17, 2025
3e169c5
Merge branch 'kmp5/feature/CP' of https://github.com/ValeevGroup/tile…
kmp5VT Mar 17, 2025
297c302
lapack throws a std::exception not a runtime_error
kmp5VT Mar 21, 2025
51f89d0
Merge branch 'master' into kmp5/feature/CP
kmp5VT Mar 21, 2025
63e0bcd
Use default world instead of provided world
kmp5VT Apr 1, 2025
664d6fa
Merge branch 'master' into kmp5/feature/CP
kmp5VT Apr 2, 2025
4cd8296
Merge branch 'master' into kmp5/feature/CP
kmp5VT Apr 4, 2025
2aa21f4
Merge branch 'master' into kmp5/feature/CP
kmp5VT Apr 6, 2025
2a18dd0
Merge branch 'master' into kmp5/feature/CP
kmp5VT May 27, 2025
b140892
Bump btas tag
kmp5VT May 27, 2025
08df87c
Merge branch 'master' into kmp5/feature/CP
evaleev Jun 11, 2025
854b2ca
Format
kmp5VT Jun 23, 2025
c5a601a
format
kmp5VT Jun 23, 2025
2026832
format
kmp5VT Jun 23, 2025
7c94fd7
Merge branch 'master' into kmp5/feature/CP
kmp5VT Jun 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions external/versions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ set(TA_TRACKED_MADNESS_PREVIOUS_TAG bd84a52766ab497dedc2f15f2162fb0eb7ec4653)
set(TA_TRACKED_MADNESS_VERSION 0.10.1)
set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1)

set(TA_TRACKED_BTAS_TAG 62d57d9b1e0c733b4b547bc9cfdd07047159dbca)
set(TA_TRACKED_BTAS_PREVIOUS_TAG 1cfcb12647c768ccd83b098c64cda723e1275e49)
set(TA_TRACKED_BTAS_TAG 74dd9277c71043c564f931c9d1f07842548d2349)
set(TA_TRACKED_BTAS_PREVIOUS_TAG 26646416e5f5829dc13d0d97fb15ae5c01b78e82)

set(TA_TRACKED_LIBRETT_TAG 6eed30d4dd2a5aa58840fe895dcffd80be7fbece)
set(TA_TRACKED_LIBRETT_PREVIOUS_TAG 354e0ccee54aeb2f191c3ce2c617ebf437e49d83)
Expand Down
2 changes: 2 additions & 0 deletions src/TiledArray/math/solvers/cp.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@

#include <TiledArray/math/solvers/cp/cp.h>
#include <TiledArray/math/solvers/cp/cp_als.h>
#include <TiledArray/math/solvers/cp/cp_thc_als.h>
#include <TiledArray/math/solvers/cp/cp_reconstruct.h>

namespace TiledArray {
using TiledArray::math::cp::CP_ALS;
using TiledArray::math::cp::CP_THC_ALS;
using TiledArray::math::cp::cp_reconstruct;
} // namespace TiledArray

Expand Down
78 changes: 66 additions & 12 deletions src/TiledArray/math/solvers/cp/cp.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,56 @@ static inline char intToAlphabet(int i) { return static_cast<char>('a' + i); }

} // namespace detail

/// normalizes "columns" (aka rows) of an updated factor matrix

/// rows of factor matrices produced by least-squares are not unit
/// normalized. This takes each row and makes it unit normalized,
/// with inverse of the normalization factor stored in this->lambda
/// \param[in,out] factor in: unnormalized factor matrix, out:
/// normalized factor matrix
template <typename Array>
void normalize_factor(Array& factor, Array& lambda) {
using Tile = typename Array::value_type;
auto& world = factor.world();
// this is what the code should look like, but expressions::einsum seems to
// be buggy lambda contains squared norms of rows
lambda = expressions::einsum(factor("r,n"), factor("r,n"), "r");

// element-wise square root to convert squared norms to norms
TiledArray::foreach_inplace(
lambda,
[](Tile& tile) {
auto lo = tile.range().lobound_data();
auto up = tile.range().upbound_data();
for (auto R = lo[0]; R < up[0]; ++R) {
const auto norm_squared_RR = tile({R});
using std::sqrt;
tile({R}) = sqrt(norm_squared_RR);
}
},
/* fence = */ true);
lambda.truncate();
lambda.make_replicated();
auto lambda_eig = array_to_eigen(lambda);

TiledArray::foreach_inplace(
factor,
[&lambda_eig](Tile& tile) {
auto lo = tile.range().lobound_data();
auto up = tile.range().upbound_data();
for (auto R = lo[0]; R < up[0]; ++R) {
const auto lambda_R = lambda_eig(R, 0);
if (lambda_R < 1e-12) continue;
auto scale_by = 1.0 / lambda_R;
for (auto N = lo[1]; N < up[1]; ++N) {
tile(R, N) *= scale_by;
}
}
},
/* fence = */ true);
factor.truncate();
}

/**
* This is a base class for the canonical polyadic (CP)
* decomposition solver. The decomposition, in general,
Expand Down Expand Up @@ -91,7 +141,7 @@ class CP {
/// \returns the fit: \f$ 1.0 - |T_{\text{exact}} - T_{\text{approx}} | \f$
double compute_rank(size_t rank, size_t rank_block_size = 0,
bool build_rank = false, double epsilonALS = 1e-3,
bool verbose = false) {
bool verbose = false, int niters = 100) {
rank_block_size = (rank_block_size == 0 ? rank : rank_block_size);
double epsilon = 1.0;
fit_tol = epsilonALS;
Expand All @@ -101,15 +151,15 @@ class CP {
do {
rank_trange = TiledRange1::make_uniform(cur_rank, rank_block_size);
build_guess(cur_rank, rank_trange);
ALS(cur_rank, 100, verbose);
ALS(cur_rank, niters, verbose);
++cur_rank;
} while (cur_rank < rank);
} else {
rank_trange = TiledRange1::make_uniform(rank, rank_block_size);
build_guess(rank, rank_trange);
ALS(rank, 100, verbose);
ALS(rank, niters, verbose);
}
return epsilon;
return this->final_fit;
}

/// This function computes the CP decomposition with an
Expand Down Expand Up @@ -140,10 +190,14 @@ class CP {
return epsilon;
}

std::vector<Array> get_factor_matrices() {
std::vector<Array> get_factor_matrices(bool with_lambda = false) {
TA_ASSERT(!cp_factors.empty(),
"CP factor matrices have not been computed)");
auto result = cp_factors;
if (with_lambda) {
result.emplace_back(lambda);
return result;
}
result.pop_back();
result.emplace_back(unNormalized_Factor);
return result;
Expand Down Expand Up @@ -185,7 +239,8 @@ class CP {
final_fit, // The final fit of the ALS
// optimization at fixed rank.
fit_tol, // Tolerance for the ALS solver
norm_reference; // used in determining the CP fit.
norm_reference, // used in determining the CP fit.
norm_ref_sq;
std::size_t converged_num =
0; // How many times the ALS solver
// has changed less than the tolerance in a row
Expand Down Expand Up @@ -259,7 +314,7 @@ class CP {
// MtKRP);
try {
MtKRP = math::linalg::cholesky_solve(W, MtKRP);
} catch (std::runtime_error& ex) {
} catch (std::exception& ex) {
// if W is near-singular try LU instead of Cholesky
if (std::string(ex.what()).find("lapack::posv failed") !=
std::string::npos) {
Expand Down Expand Up @@ -370,17 +425,16 @@ class CP {
for (size_t i = 1; i < ndim - 1; ++i, ++gram_ptr) {
W("r,rp") *= (*gram_ptr)("r,rp");
}
auto result = sqrt(W("r,rp").dot(
(unNormalized_Factor("r,n") * unNormalized_Factor("rp,n"))));
auto result = W("r,rp").dot(
(unNormalized_Factor("r,n") * unNormalized_Factor("rp,n")));
// not sure why need to fence here, but hang periodically without it
W.world().gop.fence();
return result;
};
// compute the error in the loss function and find the fit
const auto norm_cp = factor_norm(); // ||T_CP||_2
const auto squared_norm_error = norm_reference * norm_reference +
norm_cp * norm_cp -
2.0 * ref_dot_cp; // ||T - T_CP||_2^2
const auto squared_norm_error =
norm_ref_sq + norm_cp - 2.0 * ref_dot_cp; // ||T - T_CP||_2^2
// N.B. squared_norm_error is very noisy
// TA_ASSERT(squared_norm_error >= - 1e-8);
const auto norm_error = sqrt(abs(squared_norm_error));
Expand Down
9 changes: 9 additions & 0 deletions src/TiledArray/math/solvers/cp/cp_als.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class CP_ALS : public CP<Tile, Policy> {
first_gemm_dim_last.pop_back();

this->norm_reference = norm2(tref);
this->norm_ref_sq = this->norm_reference * this->norm_reference;
}

void set_factor_matrices(std::vector<DistArray<Tile, Policy>>& factors) {
cp_factors = factors;
factors_set = true;
}

protected:
Expand All @@ -77,6 +83,7 @@ class CP_ALS : public CP<Tile, Policy> {
std::string ref_indices, first_gemm_dim_one, first_gemm_dim_last;
std::vector<typename Tile::value_type> lambda;
TiledRange1 rank_trange1;
bool factors_set = false;

/// This function constructs the initial CP facotr matrices
/// stores them in CP::cp_factors vector.
Expand All @@ -93,6 +100,8 @@ class CP_ALS : public CP<Tile, Policy> {
rank_trange, reference.trange().data()[i]);
cp_factors.emplace_back(factor);
}
} else if (factors_set) {
// Do nothing and don't throw an error.
} else {
TA_EXCEPTION("Currently no implementation to increase or change rank");
}
Expand Down
Loading
Loading