@@ -6,14 +6,36 @@ mat crossprod_(const mat &X, const vec &w) {
6
6
return Y.t () * Y;
7
7
}
8
8
9
+ // vec solve_beta_(mat MX, const mat &MNU, const vec &w) {
10
+ // const vec sqrt_w = sqrt(w);
11
+ // MX.each_col() %= sqrt_w;
12
+
13
+ // mat Q, R;
14
+ // if (!qr_econ(Q, R, MX)) {
15
+ // stop("QR decomposition failed");
16
+ // }
17
+
18
+ // return solve(trimatu(R), Q.t() * (MNU.each_col() % sqrt_w), solve_opts::fast);
19
+ // }
20
+
9
21
vec solve_beta_ (mat MX, const mat &MNU, const vec &w) {
10
22
const vec sqrt_w = sqrt (w);
11
23
MX.each_col () %= sqrt_w;
24
+ mat XtX = MX.t () * MX;
12
25
13
- mat Q, R;
14
- if (!qr_econ (Q, R, MX)) {
15
- stop (" QR decomposition failed" );
26
+ // Cholesky decomposition: XtX = L * L.t()
27
+ mat L;
28
+ if (!chol (L, XtX, " lower" )) {
29
+ stop (" Cholesky decomposition failed." );
16
30
}
17
31
18
- return solve (trimatu (R), Q.t () * (MNU.each_col () % sqrt_w), solve_opts::fast);
32
+ vec Xty = MX.t () * (MNU.each_col () % sqrt_w);
33
+
34
+ // Solve L * z = Xty
35
+ vec z = solve (trimatl (L), Xty, solve_opts::fast);
36
+
37
+ // Solve L.t() * beta = z
38
+ vec beta = solve (trimatu (L.t ()), z, solve_opts::fast);
39
+
40
+ return beta;
19
41
}
0 commit comments