1
- // 02_center_variables.cpp (refactored using Armadillo types)
2
1
#include " 00_main.h"
3
2
4
3
// Method of alternating projections (Halperin)
@@ -7,16 +6,16 @@ void center_variables_(mat &V, const vec &w, const list &klist,
7
6
const int &iter_interrupt, const int &iter_ssr) {
8
7
// Auxiliary variables (fixed)
9
8
const size_t I = static_cast <size_t >(max_iter), N = V.n_rows , P = V.n_cols ,
10
- K = klist.size (),
11
- iter_check_interrupt0 = static_cast <size_t >(iter_interrupt),
12
- iter_check_ssr0 = static_cast <size_t >(iter_ssr);
9
+ K = klist.size (), iint0 = static_cast <size_t >(iter_interrupt),
10
+ isr0 = static_cast <size_t >(iter_ssr);
13
11
const double inv_sw = 1.0 / accu (w);
14
12
15
13
// Auxiliary variables (storage)
16
- size_t iter, j, k, p, J, iter_check_interrupt = iter_check_interrupt0,
17
- iter_check_ssr = iter_check_ssr0;
18
- double coef, xbar, ratio, ssr, ssq, ratio0, ssr0;
14
+ size_t iter, iint, isr, j, jj, k, n, p, J, JJ;
15
+ double num, coef, xbar, ratio, ssr, ssq, ratio0, ssr0;
19
16
vec x (N), x0 (N), Gx (N), G2x (N), deltaG (N), delta2 (N);
17
+
18
+ // Precompute groups into fields
20
19
field<field<uvec>> group_indices (K);
21
20
field<vec> group_inverse_weights (K);
22
21
for (k = 0 ; k < K; ++k) {
@@ -26,59 +25,60 @@ void center_variables_(mat &V, const vec &w, const list &klist,
26
25
vec invs (J);
27
26
for (j = 0 ; j < J; ++j) {
28
27
idxs (j) = as_uvec (as_cpp<integers>(jlist[j]));
29
- ;
30
28
invs (j) = 1.0 / accu (w.elem (idxs (j)));
31
29
}
32
- group_indices (k) = idxs;
33
- group_inverse_weights (k) = invs;
30
+ group_indices (k) = std::move ( idxs) ;
31
+ group_inverse_weights (k) = std::move ( invs) ;
34
32
}
35
33
34
+ // Single nested‐field projection helper
35
+ auto project = [&](vec &v) {
36
+ J = group_indices.n_elem ;
37
+ for (j = 0 ; j < J; ++j) {
38
+ auto &idxs = group_indices (j);
39
+ auto &invs = group_inverse_weights (j);
40
+ JJ = idxs.n_elem ;
41
+ for (jj = 0 ; jj < JJ; ++jj) {
42
+ const uvec &coords = idxs (jj);
43
+ xbar = dot (w.elem (coords), v.elem (coords)) * invs (jj);
44
+ v.elem (coords) -= xbar;
45
+ }
46
+ }
47
+ };
48
+
49
+ // Column‐wise centering
36
50
for (p = 0 ; p < P; ++p) {
37
51
x = V.col (p);
38
52
ratio0 = std::numeric_limits<double >::infinity ();
39
53
ssr0 = std::numeric_limits<double >::infinity ();
40
54
55
+ // reset per‐column interrupt
56
+ iint = iint0;
57
+ isr = isr0;
58
+
41
59
for (iter = 0 ; iter < I; ++iter) {
42
- if (iter == iter_check_interrupt ) {
60
+ if (iter == iint ) {
43
61
check_user_interrupt ();
44
- iter_check_interrupt += iter_check_interrupt0 ;
62
+ iint += iint0 ;
45
63
}
46
64
47
65
x0 = x;
48
66
49
- // Halperin projection
50
- for (k = 0 ; k < K; ++k) {
51
- field<uvec> &idxs = group_indices (k);
52
- J = idxs.n_elem ;
53
- vec &invs = group_inverse_weights (k);
54
- for (j = 0 ; j < J; ++j) {
55
- const uvec &coords = idxs (j);
56
- xbar = dot (w.elem (coords), x.elem (coords)) * invs (j);
57
- x.elem (coords) -= xbar;
58
- }
67
+ // 1) main projection
68
+ project (x);
69
+ num = 0.0 ;
70
+ for (n = 0 ; n < N; ++n) {
71
+ num += std::abs (x[n] - x0[n]) / (1.0 + std::abs (x0[n])) * w[n];
59
72
}
60
-
61
- // Convergence check
62
- ratio = dot (abs (x - x0) / (1.0 + abs (x0)), w) * inv_sw;
73
+ ratio = num * inv_sw;
63
74
if (ratio < tol)
64
75
break ;
65
76
66
- // Acceleration every 5 iters
67
- if (iter > 5 && (iter % 5 ) == 0 ) {
77
+ // 2) acceleration every 5 iters
78
+ if (iter >= 5 && (iter % 5 ) == 0 ) {
68
79
Gx = x;
69
- // Second projection
70
- for (size_t k = 0 ; k < K; ++k) {
71
- field<uvec> &idxs = group_indices (k);
72
- vec &invs = group_inverse_weights (k);
73
- for (j = 0 ; j < idxs.n_elem ; ++j) {
74
- const uvec &coords = idxs (j);
75
- xbar = dot (w.elem (coords), Gx.elem (coords)) * invs (j);
76
- Gx.elem (coords) -= xbar;
77
- }
78
- }
80
+ project (Gx);
79
81
G2x = Gx;
80
-
81
- // Compute deltas
82
82
deltaG = G2x - x;
83
83
delta2 = G2x - 2.0 * x + x0;
84
84
ssq = dot (delta2, delta2);
@@ -92,17 +92,17 @@ void center_variables_(mat &V, const vec &w, const list &klist,
92
92
}
93
93
}
94
94
95
- // SSR check
96
- if (iter == iter_check_ssr && iter > 0 ) {
95
+ // 3) SSR‐based early exit
96
+ if (iter == isr && iter > 0 ) {
97
97
check_user_interrupt ();
98
- iter_check_ssr += iter_check_ssr0 ;
98
+ isr += isr0 ;
99
99
ssr = dot (x % x, w) * inv_sw;
100
- if (fabs (ssr - ssr0) / (1.0 + fabs (ssr0)) < tol)
100
+ if (std:: fabs (ssr - ssr0) / (1.0 + std:: fabs (ssr0)) < tol)
101
101
break ;
102
102
ssr0 = ssr;
103
103
}
104
104
105
- // Early exit
105
+ // 4) early exit
106
106
if (iter > 3 && (ratio0 / ratio) < 1.1 && ratio < tol * 20 )
107
107
break ;
108
108
ratio0 = ratio;
@@ -114,8 +114,8 @@ void center_variables_(mat &V, const vec &w, const list &klist,
114
114
115
115
[[cpp11::register ]] doubles_matrix<>
116
116
center_variables_r_ (const doubles_matrix<> &V_r, const doubles &w_r,
117
- const list &klist, const double & tol, const int & max_iter,
118
- const int & iter_interrupt, const int & iter_ssr) {
117
+ const list &klist, const double tol, const int max_iter,
118
+ const int iter_interrupt, const int iter_ssr) {
119
119
mat V = as_mat (V_r);
120
120
center_variables_ (V, as_col (w_r), klist, tol, max_iter, iter_interrupt,
121
121
iter_ssr);
0 commit comments