Skip to content

Commit d65be44

Browse files
committed
using OpenMP safely
1 parent 9f48133 commit d65be44

11 files changed

+241
-163
lines changed

R/cpp11.R

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@ center_variables_r_ <- function(V_r, w_r, klist, tol, maxiter, interrupt_iter) {
44
.Call(`_capybara_center_variables_r_`, V_r, w_r, klist, tol, maxiter, interrupt_iter)
55
}
66

7+
# Forwards every argument, unchanged and in order, to the native routine
# registered as `_capybara_felm_fit_` and returns that routine's result.
felm_fit_ <- function(y_r, x_r, wt_r, control, k_list) {
  .Call(
    `_capybara_felm_fit_`,
    y_r, x_r, wt_r, control, k_list
  )
}
10+
11+
# Forwards every argument, unchanged and in order, to the native routine
# registered as `_capybara_feglm_fit_` and returns that routine's result.
feglm_fit_ <- function(beta_r, eta_r, y_r, x_r, wt_r, theta, family, control, k_list) {
  .Call(
    `_capybara_feglm_fit_`,
    beta_r, eta_r, y_r, x_r, wt_r,
    theta, family, control, k_list
  )
}
14+
15+
# Forwards every argument, unchanged and in order, to the native routine
# registered as `_capybara_feglm_offset_fit_` and returns that routine's result.
feglm_offset_fit_ <- function(eta_r, y_r, offset_r, wt_r, family, control, k_list) {
  .Call(
    `_capybara_feglm_offset_fit_`,
    eta_r, y_r, offset_r, wt_r,
    family, control, k_list
  )
}
18+
719
# Forwards every argument, unchanged and in order, to the native routine
# registered as `_capybara_get_alpha_` and returns that routine's result.
get_alpha_ <- function(p_r, klist, control) {
  .Call(
    `_capybara_get_alpha_`,
    p_r, klist, control
  )
}
@@ -23,15 +35,3 @@ group_sums_var_ <- function(M_r, jlist) {
2335
# Forwards every argument, unchanged and in order, to the native routine
# registered as `_capybara_group_sums_cov_` and returns that routine's result.
group_sums_cov_ <- function(M_r, N_r, jlist) {
  .Call(
    `_capybara_group_sums_cov_`,
    M_r, N_r, jlist
  )
}
26-
27-
feglm_fit_ <- function(beta_r, eta_r, y_r, x_r, wt_r, theta, family, control, k_list) {
28-
.Call(`_capybara_feglm_fit_`, beta_r, eta_r, y_r, x_r, wt_r, theta, family, control, k_list)
29-
}
30-
31-
feglm_offset_fit_ <- function(eta_r, y_r, offset_r, wt_r, family, control, k_list) {
32-
.Call(`_capybara_feglm_offset_fit_`, eta_r, y_r, offset_r, wt_r, family, control, k_list)
33-
}
34-
35-
felm_fit_ <- function(y_r, x_r, wt_r, control, k_list) {
36-
.Call(`_capybara_felm_fit_`, y_r, x_r, wt_r, control, k_list)
37-
}

src/00_main.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ using namespace cpp11;
88

99
// used across the scripts
1010

11+
#ifdef _OPENMP
12+
const size_t n_threads = omp_get_max_threads();
13+
#endif
14+
1115
void center_variables_(mat &V, const vec &w, const list &klist,
1216
const double &tol, const size_t &maxiter,
1317
const size_t &interrupt_iter);

src/01_center_variables.cpp

Lines changed: 0 additions & 87 deletions
This file was deleted.
File renamed without changes.

src/02_center_variables.cpp

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#include "00_main.h"

#include <atomic>
#include <exception>
2+
3+
// Halperin method with aggressive precomputing and optimized memory use
4+
void center_variables_(mat &V, const vec &w, const list &klist,
5+
const double &tol, const size_t &maxiter,
6+
const size_t &interrupt_iter) {
7+
// Auxiliary variables (fixed)
8+
const size_t N = V.n_rows;
9+
const size_t P = V.n_cols;
10+
const size_t K = klist.size();
11+
const double inv_sw = 1.0 / accu(w);
12+
13+
// Auxiliary variables (storage)
14+
vec x(N), x0(N);
15+
field<field<uvec>> group_indices(K);
16+
field<vec> group_inverse_weights(K);
17+
18+
// Precompute indices and inverse weights once
19+
for (size_t k = 0; k < K; ++k) {
20+
const list &jlist = klist[k];
21+
size_t J = jlist.size();
22+
23+
field<uvec> indices(J);
24+
vec inverse_weights(J);
25+
26+
for (size_t j = 0; j < J; ++j) {
27+
indices(j) = as_uvec(as_cpp<integers>(jlist[j]));
28+
inverse_weights(j) = 1.0 / accu(w.elem(indices(j)));
29+
}
30+
31+
group_indices(k) = std::move(indices);
32+
group_inverse_weights(k) = std::move(inverse_weights);
33+
}
34+
35+
// Perform Halperin projections, parallelizing columns
36+
#ifdef _OPENMP
37+
#pragma omp parallel for schedule(static) num_threads(n_threads) private(x, x0)
38+
#endif
39+
for (size_t p = 0; p < P; ++p) {
40+
x = V.col(p);
41+
size_t interrupt = interrupt_iter;
42+
43+
for (size_t iter = 0; iter < maxiter; ++iter) {
44+
if (iter == interrupt) {
45+
// Only main thread checks for interrupts
46+
#ifdef _OPENMP
47+
if (omp_get_thread_num() == 0) {
48+
check_user_interrupt();
49+
}
50+
#else
51+
check_user_interrupt();
52+
#endif
53+
interrupt += interrupt_iter;
54+
}
55+
56+
x0 = x;
57+
58+
// Project onto group means
59+
for (size_t l = 0; l < K; ++l) {
60+
const size_t L = group_indices(l).size();
61+
for (size_t m = 0; m < L; ++m) {
62+
const uvec &coords = group_indices(l)(m);
63+
const double xbar =
64+
dot(w.elem(coords), x.elem(coords)) * group_inverse_weights(l)(m);
65+
x.elem(coords) -= xbar;
66+
}
67+
}
68+
69+
// Check convergence (correct placement)
70+
double ratio = dot(abs(x - x0) / (1.0 + abs(x0)), w) * inv_sw;
71+
if (ratio < tol) {
72+
break;
73+
}
74+
}
75+
76+
// Assign back at convergence
77+
V.col(p) = std::move(x);
78+
}
79+
}
80+
81+
// R-facing entry point: copies the inputs into Armadillo objects with
// as_Mat()/as_Col(), centers the matrix in place through center_variables_(),
// and hands the centered matrix back as an R doubles matrix.
[[cpp11::register]] doubles_matrix<> center_variables_r_(
    const doubles_matrix<> &V_r, const doubles &w_r, const list &klist,
    const double &tol, const int &maxiter, const int &interrupt_iter) {
  mat centered = as_Mat(V_r);
  const vec weights = as_Col(w_r);

  // The callee takes size_t counters; spell out the int -> size_t conversion
  // that the call would otherwise perform implicitly.
  center_variables_(centered, weights, klist, tol,
                    static_cast<size_t>(maxiter),
                    static_cast<size_t>(interrupt_iter));

  return as_doubles_matrix(std::move(centered));
}
89+
90+
// Kaczmarz demeaning
91+
// void center_variables_(mat &V, const vec &w, const list &klist,
92+
// const double &tol, const size_t &maxiter,
93+
// const size_t &interrupt_iter) {
94+
// // Auxiliary variables (fixed)
95+
// const size_t P = V.n_cols;
96+
// const size_t K = klist.size();
97+
// const double inv_sw = 1.0 / accu(w);
98+
99+
// // Auxiliary variables (storage)
100+
// size_t interrupt = static_cast<size_t>(interrupt_iter);
101+
// uvec coords;
102+
103+
// // Precompute group indices and weights parallelizing over groups
104+
// field<field<uvec>> group_indices(K);
105+
// field<vec> group_inverse_weights(K);
106+
107+
// #ifdef _OPENMP
108+
// #pragma omp parallel for schedule(static, n_threads)
109+
// #endif
110+
// for (size_t k = 0; k < K; ++k) {
111+
// const list &jlist = klist[k];
112+
// size_t J = jlist.size();
113+
114+
// field<uvec> indices(J);
115+
// vec inverse_weights(J);
116+
117+
// for (size_t j = 0; j < J; ++j) {
118+
// indices(j) = as_uvec(as_cpp<integers>(jlist[j]));
119+
// inverse_weights(j) = 1.0 / accu(w.elem(indices(j)));
120+
// }
121+
122+
// group_indices(k) = std::move(indices);
123+
// group_inverse_weights(k) = std::move(inverse_weights);
124+
// }
125+
126+
// // Kaczmarz iterations parallelizing over columns
127+
// #ifdef _OPENMP
128+
// #pragma omp parallel for schedule(static, n_threads)
129+
// #endif
130+
// for (size_t p = 0; p < P; ++p) {
131+
// for (size_t iter = 0; iter < maxiter; ++iter) {
132+
// if (iter == interrupt) {
133+
// check_user_interrupt();
134+
// interrupt += 1000;
135+
// }
136+
137+
// vec x = V.col(p);
138+
// vec x0 = x;
139+
// double ratio;
140+
141+
// for (size_t l = 0; l < K; ++l) {
142+
// size_t L = group_indices(l).size();
143+
// if (L == 0) continue;
144+
145+
// for (size_t m = 0; m < L; ++m) {
146+
// const uvec &coords = group_indices(l)(m);
147+
// double xbar =
148+
// dot(w.elem(coords), x.elem(coords)) / accu(w.elem(coords));
149+
// x.elem(coords) -= xbar;
150+
// }
151+
// }
152+
153+
// ratio = dot(abs(x - x0) / (1.0 + abs(x0)), w) * inv_sw;
154+
// if (ratio < tol) {
155+
// break;
156+
// }
157+
158+
// V.col(p) = x;
159+
// }
160+
// }
161+
// }

src/07_lm_fit.cpp renamed to src/03_lm_fit.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,10 @@
4242

4343
// Generate result list
4444

45-
writable::list out(4);
46-
out[0] = as_doubles(beta);
47-
out[1] = as_doubles(fitted);
48-
out[2] = as_doubles(w);
49-
out[3] = as_doubles_matrix(H);
50-
out.attr("names") = writable::strings(
51-
{"coefficients", "fitted.values", "weights", "hessian"});
52-
53-
return out;
45+
return writable::list(
46+
{"coefficients"_nm = as_doubles(std::move(beta)),
47+
"fitted.values"_nm = as_doubles(std::move(fitted)),
48+
"weights"_nm = as_doubles(std::move(w)),
49+
"hessian"_nm = as_doubles_matrix(std::move(H))
50+
});
5451
}

src/05_glm_fit.cpp renamed to src/04_glm_fit.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -307,26 +307,22 @@ bool valid_mu_(const vec &mu, const FamilyType &fam) {
307307

308308
// Generate result list
309309

310-
writable::list out(8);
311-
312-
out[0] = as_doubles(beta);
313-
out[1] = as_doubles(eta);
314-
out[2] = as_doubles(wt);
315-
out[3] = as_doubles_matrix(H);
316-
out[4] = writable::doubles({dev});
317-
out[5] = writable::doubles({null_dev});
318-
out[6] = writable::logicals({conv});
319-
out[7] = writable::integers({static_cast<int>(iter + 1)});
320-
321-
out.attr("names") =
322-
writable::strings({"coefficients", "eta", "weights", "hessian",
323-
"deviance", "null_deviance", "conv", "iter"});
310+
writable::list out({
311+
"coefficients"_nm = as_doubles(std::move(beta)),
312+
"eta"_nm = as_doubles(std::move(eta)),
313+
"weights"_nm = as_doubles(std::move(wt)),
314+
"hessian"_nm = as_doubles_matrix(std::move(H)),
315+
"deviance"_nm = writable::doubles({dev}),
316+
"null_deviance"_nm = writable::doubles({null_dev}),
317+
"conv"_nm = writable::logicals({conv}),
318+
"iter"_nm = writable::integers({static_cast<int>(iter + 1)})
319+
});
324320

325321
if (keep_mx) {
326322
mat x_cpp = as_Mat(x_r);
327323
center_variables_(x_cpp, w, k_list, center_tol, iter_center_max,
328324
iter_interrupt);
329-
out.push_back({"MX"_nm = as_doubles_matrix(x_cpp)});
325+
out.push_back({"MX"_nm = as_doubles_matrix(std::move(x_cpp))});
330326
}
331327

332328
return out;

src/06_glm_offset_fit.cpp renamed to src/05_glm_offset_fit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,5 @@ feglm_offset_fit_(const doubles &eta_r, const doubles &y_r,
9898
Myadj = Myadj - yadj;
9999
}
100100

101-
return as_doubles(eta);
101+
return as_doubles(std::move(eta));
102102
}

0 commit comments

Comments
 (0)