Skip to content

Commit 2da5b4e

Browse files
committed
use cholesky
1 parent d65be44 commit 2da5b4e

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

src/01_linear_algebra.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,36 @@ mat crossprod_(const mat &X, const vec &w) {
66
return Y.t() * Y;
77
}
88

9+
// vec solve_beta_(mat MX, const mat &MNU, const vec &w) {
10+
// const vec sqrt_w = sqrt(w);
11+
// MX.each_col() %= sqrt_w;
12+
13+
// mat Q, R;
14+
// if (!qr_econ(Q, R, MX)) {
15+
// stop("QR decomposition failed");
16+
// }
17+
18+
// return solve(trimatu(R), Q.t() * (MNU.each_col() % sqrt_w), solve_opts::fast);
19+
// }
20+
921
vec solve_beta_(mat MX, const mat &MNU, const vec &w) {
1022
const vec sqrt_w = sqrt(w);
1123
MX.each_col() %= sqrt_w;
24+
mat XtX = MX.t() * MX;
1225

13-
mat Q, R;
14-
if (!qr_econ(Q, R, MX)) {
15-
stop("QR decomposition failed");
26+
// Cholesky decomposition: XtX = L * L.t()
27+
mat L;
28+
if (!chol(L, XtX, "lower")) {
29+
stop("Cholesky decomposition failed.");
1630
}
1731

18-
return solve(trimatu(R), Q.t() * (MNU.each_col() % sqrt_w), solve_opts::fast);
32+
vec Xty = MX.t() * (MNU.each_col() % sqrt_w);
33+
34+
// Solve L * z = Xty
35+
vec z = solve(trimatl(L), Xty, solve_opts::fast);
36+
37+
// Solve L.t() * beta = z
38+
vec beta = solve(trimatu(L.t()), z, solve_opts::fast);
39+
40+
return beta;
1941
}

0 commit comments

Comments
 (0)