Skip to content

Commit a72e243

Browse files
committed
try parallelized centering
1 parent 8f3f8bf commit a72e243

File tree

3 files changed

+24
-14
lines changed

3 files changed

+24
-14
lines changed

dev/check_speed.r

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,11 @@ mark(
3434
fixest::fepois(form, data = d)$coefficients["rta"],
3535
iterations = 10L
3636
)
37+
38+
Rprof("capybara_profile.out")
39+
mod <- capybara::fepoisson(form, data = d)
40+
Rprof(NULL)
41+
42+
profvis::profvis({
43+
mod <- capybara::fepoisson(form, data = d)
44+
})

src/00_main.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ using namespace cpp11;
99

1010
// used across the scripts
1111

12-
void center_variables_(Mat<double> &V, const Col<double> &w,
13-
const list &klist, const double &tol,
14-
const int &maxiter);
12+
void center_variables_(Mat<double> &V, const Col<double> &w, const list &klist,
13+
const double &tol, const int &maxiter);
1514

1615
Col<double> solve_beta_(Mat<double> MX, const Mat<double> &MNU,
1716
const Col<double> &w);

src/01_center_variables.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
#include "00_main.h"
23

34
// Method of alternating projections (Halperin)
@@ -11,9 +12,8 @@ void center_variables_(Mat<double> &V, const Col<double> &w, const list &klist,
1112
const double inv_sw = 1.0 / accu(w);
1213

1314
// Auxiliary variables (storage)
14-
size_t iter, j, k, p, J, interrupt_iter = 1000;
15+
size_t iter, j, k, J, interrupt_iter = 1000;
1516
double meanj, ratio;
16-
Col<double> x(N), x0(N);
1717

1818
// Precompute group indices and weights
1919
field<field<uvec>> group_indices(K);
@@ -30,16 +30,19 @@ void center_variables_(Mat<double> &V, const Col<double> &w, const list &klist,
3030
}
3131
}
3232

33-
// Halperin projections
34-
// #ifdef _OPENMP
35-
// #pragma omp parallel for schedule(dynamic) private(x, x0, iter, k, j,
36-
// meanj, J, ratio) #endif
37-
for (p = 0; p < P; ++p) {
33+
// Halperin projections
34+
#ifdef _OPENMP
35+
#pragma omp parallel for schedule(dynamic) private( \
36+
iter, k, j, J, meanj, ratio) shared(V, w, group_indices, group_weights)
37+
#endif
38+
for (size_t p = 0; p < P; ++p) {
3839
// Center each variable
39-
x = V.col(p);
40+
Col<double> x = V.col(p);
41+
Col<double> x0(N);
4042

4143
for (iter = 0; iter < I; ++iter) {
4244
if (iter == interrupt_iter) {
45+
#pragma omp critical
4346
check_user_interrupt();
4447
interrupt_iter += 1000;
4548
}
@@ -49,7 +52,7 @@ void center_variables_(Mat<double> &V, const Col<double> &w, const list &klist,
4952

5053
// Alternate between categories
5154
for (k = 0; k < K; ++k) {
52-
// Substract the weighted group means of category 'k'
55+
// Subtract the weighted group means of category 'k'
5356
J = group_indices(k).size();
5457
if (J == 0)
5558
continue; // Skip empty groups
@@ -64,15 +67,15 @@ void center_variables_(Mat<double> &V, const Col<double> &w, const list &klist,
6467

6568
// Break loop if convergence is reached
6669
ratio = accu(abs(x - x0) / (1.0 + abs(x0)) % w) * inv_sw;
67-
// ratio = norm(x - x0, 2) * inv_sw;
6870
if (ratio < tol)
6971
break;
7072
}
7173
V.col(p) = x;
7274
}
7375
}
7476

75-
[[cpp11::register]] doubles_matrix<> center_variables_r_(const doubles_matrix<> &V_r, const doubles &w_r,
77+
[[cpp11::register]] doubles_matrix<>
78+
center_variables_r_(const doubles_matrix<> &V_r, const doubles &w_r,
7679
const list &klist, const double &tol, const int &maxiter) {
7780
Mat<double> V = as_Mat(V_r);
7881
Col<double> w = as_Col(w_r);

0 commit comments

Comments
 (0)