Skip to content

Commit 5738441

Browse files
committed
some speedups
1 parent 94e3bd6 commit 5738441

File tree

10 files changed

+117
-117
lines changed

10 files changed

+117
-117
lines changed

R/feglm.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ feglm <- function(
165165
lhs <- NA # just to avoid global variable warning
166166
nobs_na <- NA
167167
nobs_full <- NA
168+
weights_vec <- NA
169+
weights_col <- NA
168170
model_frame_(data, formula, weights)
169171

170172
# Ensure that model response is in line with the chosen model ----
@@ -202,13 +204,13 @@ feglm <- function(
202204
# Extract weights if required ----
203205
if (is.null(weights)) {
204206
wt <- rep(1.0, nt)
205-
} else if (exists("weights_vec")) {
207+
} else if (!all(is.na(weights_vec))) {
206208
# Weights provided as vector
207209
wt <- weights_vec
208210
if (length(wt) != nrow(data)) {
209211
stop("Length of weights vector must equal number of observations.", call. = FALSE)
210212
}
211-
} else if (exists("weights_col")) {
213+
} else if (!all(is.na(weights_col))) {
212214
# Weights provided as formula - use the extracted column name
213215
wt <- data[[weights_col]]
214216
} else {

R/felm.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ felm <- function(formula = NULL, data = NULL, weights = NULL, control = NULL) {
112112
lhs <- NA # just to avoid global variable warning
113113
nobs_na <- NA
114114
nobs_full <- NA
115+
weights_vec <- NA
116+
weights_col <- NA
115117
model_frame_(data, formula, weights)
116118

117119
# Get names of the fixed effects variables and sort ----
@@ -143,13 +145,13 @@ felm <- function(formula = NULL, data = NULL, weights = NULL, control = NULL) {
143145
# Extract weights if required ----
144146
if (is.null(weights)) {
145147
wt <- rep(1.0, nt)
146-
} else if (exists("weights_vec")) {
148+
} else if (!all(is.na(weights_vec))) {
147149
# Weights provided as vector
148150
wt <- weights_vec
149151
if (length(wt) != nrow(data)) {
150152
stop("Length of weights vector must equal number of observations.", call. = FALSE)
151153
}
152-
} else if (exists("weights_col")) {
154+
} else if (!all(is.na(weights_col))) {
153155
# Weights provided as formula - use the extracted column name
154156
wt <- data[[weights_col]]
155157
} else {

src/00_main.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
1+
#pragma once
2+
13
#include <armadillo.hpp>
24
#include <cpp11.hpp>
35
#include <cpp11armadillo.hpp>
46
#include <regex>
57
#include <unordered_map>
68

7-
using namespace arma;
8-
using namespace cpp11;
9+
// using namespace arma;
10+
using arma::field;
11+
using arma::mat;
12+
using arma::uvec;
13+
using arma::uword;
14+
using arma::vec;
15+
16+
// using namespace cpp11;
17+
using cpp11::doubles;
18+
using cpp11::doubles_matrix;
19+
using cpp11::integers;
20+
using cpp11::list;
921

1022
// used across the scripts
1123

src/01_linear_algebra.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
[[cpp11::register]] int check_linear_dependence_svd_(const doubles &y,
77
const doubles_matrix<> &x,
88
const int &p) {
9-
mat Y = as_mat(y);
10-
mat X = as_mat(x);
11-
X = join_rows(Y, X); // paste y and x together
12-
int r = rank(X);
9+
const mat Y = as_mat(y);
10+
const mat X = as_mat(x);
11+
mat Z(Y.n_rows, 1 + X.n_cols);
12+
Z.col(0) = Y;
13+
Z.cols(1, Z.n_cols - 1) = X;
14+
int r = rank(Z);
1315
if (r < p) {
1416
return 1;
1517
}
@@ -42,7 +44,6 @@ vec solve_beta_(mat MX, const mat &MNU, const vec &w) {
4244
const vec sqrt_w = sqrt(w);
4345

4446
MX.each_col() %= sqrt_w;
45-
mat WMNU = MNU.each_col() % sqrt_w;
4647

4748
mat XtX = MX.t() * MX;
4849
vec XtY = MX.t() * (MNU.each_col() % sqrt_w);

src/02_center_variables.cpp

Lines changed: 52 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// 02_center_variables.cpp (refactored using Armadillo types)
12
#include "00_main.h"
23

34
// Method of alternating projections (Halperin)
@@ -12,108 +13,86 @@ void center_variables_(mat &V, const vec &w, const list &klist,
1213
const double inv_sw = 1.0 / accu(w);
1314

1415
// Auxiliary variables (storage)
15-
size_t iter, j, k, l, m, p, L, J,
16-
iter_check_interrupt = iter_check_interrupt0,
17-
iter_check_ssr = iter_check_ssr0;
18-
double xbar, ratio, ratio0, ssr, ssr0, vprod, ssq, coef;
19-
vec x(N), x0(N);
16+
size_t iter, j, k, p, J, iter_check_interrupt = iter_check_interrupt0,
17+
iter_check_ssr = iter_check_ssr0;
18+
double coef, xbar, ratio, ssr, ssq, ratio0, ssr0;
19+
vec x(N), x0(N), Gx(N), G2x(N), deltaG(N), delta2(N);
2020
field<field<uvec>> group_indices(K);
2121
field<vec> group_inverse_weights(K);
22-
23-
// Precompute group indices and weights
2422
for (k = 0; k < K; ++k) {
2523
const list &jlist = klist[k];
2624
J = jlist.size();
27-
28-
field<uvec> indices(J);
29-
vec inverse_weights(J);
30-
25+
field<uvec> idxs(J);
26+
vec invs(J);
3127
for (j = 0; j < J; ++j) {
32-
indices(j) = as_uvec(as_cpp<integers>(jlist[j]));
33-
inverse_weights(j) = 1.0 / accu(w.elem(indices(j)));
28+
idxs(j) = as_uvec(as_cpp<integers>(jlist[j]));
29+
;
30+
invs(j) = 1.0 / accu(w.elem(idxs(j)));
3431
}
35-
36-
group_indices(k) = indices;
37-
group_inverse_weights(k) = inverse_weights;
32+
group_indices(k) = idxs;
33+
group_inverse_weights(k) = invs;
3834
}
3935

40-
// Pre-allocate vectors for acceleration (outside the loop to avoid
41-
// reallocation)
42-
vec G_x(N), G2_x(N), delta_G_x(N), delta2_x(N);
43-
44-
// Halperin projections parallelizing over columns
4536
for (p = 0; p < P; ++p) {
4637
x = V.col(p);
47-
ratio0 = std::numeric_limits<double>::max();
48-
ssr0 = std::numeric_limits<double>::max();
38+
ratio0 = std::numeric_limits<double>::infinity();
39+
ssr0 = std::numeric_limits<double>::infinity();
4940

5041
for (iter = 0; iter < I; ++iter) {
51-
// Check for user interrupt less frequently
5242
if (iter == iter_check_interrupt) {
5343
check_user_interrupt();
5444
iter_check_interrupt += iter_check_interrupt0;
5545
}
5646

57-
x0 = x; // Save current x
58-
59-
// Apply the Halperin projection
60-
for (l = 0; l < K; ++l) {
61-
L = group_indices(l).size();
62-
if (L == 0)
63-
continue;
47+
x0 = x;
6448

65-
for (m = 0; m < L; ++m) {
66-
const uvec &coords = group_indices(l)(m);
67-
xbar =
68-
dot(w.elem(coords), x.elem(coords)) * group_inverse_weights(l)(m);
49+
// Halperin projection
50+
for (k = 0; k < K; ++k) {
51+
field<uvec> &idxs = group_indices(k);
52+
J = idxs.n_elem;
53+
vec &invs = group_inverse_weights(k);
54+
for (j = 0; j < J; ++j) {
55+
const uvec &coords = idxs(j);
56+
xbar = dot(w.elem(coords), x.elem(coords)) * invs(j);
6957
x.elem(coords) -= xbar;
7058
}
7159
}
7260

73-
// First convergence check
61+
// Convergence check
7462
ratio = dot(abs(x - x0) / (1.0 + abs(x0)), w) * inv_sw;
7563
if (ratio < tol)
7664
break;
7765

78-
// Apply acceleration less frequently - only every 5 iterations instead of
79-
// 3 This reduces overhead while still getting acceleration benefits
80-
if (iter > 5 && iter % 5 == 0) {
81-
G_x = x; // G(x) - the result after one projection
82-
83-
// Apply another projection to get G(G(x))
84-
for (l = 0; l < K; ++l) {
85-
L = group_indices(l).size();
86-
if (L == 0)
87-
continue;
88-
89-
for (m = 0; m < L; ++m) {
90-
const uvec &coords = group_indices(l)(m);
91-
xbar = dot(w.elem(coords), G_x.elem(coords)) *
92-
group_inverse_weights(l)(m);
93-
G_x.elem(coords) -= xbar;
66+
// Acceleration every 5 iters
67+
if (iter > 5 && (iter % 5) == 0) {
68+
Gx = x;
69+
// Second projection
70+
for (size_t k = 0; k < K; ++k) {
71+
field<uvec> &idxs = group_indices(k);
72+
vec &invs = group_inverse_weights(k);
73+
for (j = 0; j < idxs.n_elem; ++j) {
74+
const uvec &coords = idxs(j);
75+
xbar = dot(w.elem(coords), Gx.elem(coords)) * invs(j);
76+
Gx.elem(coords) -= xbar;
9477
}
9578
}
96-
G2_x = G_x; // G²(x)
97-
98-
// Irons & Tuck acceleration formula
99-
delta_G_x = G2_x - x;
100-
delta2_x = G2_x - 2 * x + x0;
101-
102-
ssq = dot(delta2_x, delta2_x);
103-
if (ssq > 1e-10) { // Add numerical stability threshold
104-
vprod = dot(delta_G_x, delta2_x);
105-
coef = vprod / ssq;
106-
107-
// Limit coefficient to prevent excessive extrapolation
108-
if (coef > 0 && coef < 2.0) {
109-
x = G2_x - coef * delta_G_x;
79+
G2x = Gx;
80+
81+
// Compute deltas
82+
deltaG = G2x - x;
83+
delta2 = G2x - 2.0 * x + x0;
84+
ssq = dot(delta2, delta2);
85+
if (ssq > 1e-10) {
86+
coef = dot(deltaG, delta2) / ssq;
87+
if (coef > 0.0 && coef < 2.0) {
88+
x = G2x - coef * deltaG;
11089
} else {
111-
x = G2_x; // Use G2_x if coefficient is out of bounds
90+
x = G2x;
11291
}
11392
}
11493
}
11594

116-
// Check SSR improvement less frequently
95+
// SSR check
11796
if (iter == iter_check_ssr && iter > 0) {
11897
check_user_interrupt();
11998
iter_check_ssr += iter_check_ssr0;
@@ -123,15 +102,13 @@ void center_variables_(mat &V, const vec &w, const list &klist,
123102
ssr0 = ssr;
124103
}
125104

126-
// Early stopping based on ratio improvement
127-
if (iter > 3 && ratio0 / ratio < 1.1 && ratio < tol * 20) {
105+
// Early exit
106+
if (iter > 3 && (ratio0 / ratio) < 1.1 && ratio < tol * 20)
128107
break;
129-
}
130-
131108
ratio0 = ratio;
132109
}
133110

134-
V.col(p) = std::move(x);
111+
V.col(p) = x;
135112
}
136113
}
137114

@@ -140,7 +117,7 @@ center_variables_r_(const doubles_matrix<> &V_r, const doubles &w_r,
140117
const list &klist, const double &tol, const int &max_iter,
141118
const int &iter_interrupt, const int &iter_ssr) {
142119
mat V = as_mat(V_r);
143-
vec w = as_col(w_r);
144-
center_variables_(V, w, klist, tol, max_iter, iter_interrupt, iter_ssr);
120+
center_variables_(V, as_col(w_r), klist, tol, max_iter, iter_interrupt,
121+
iter_ssr);
145122
return as_doubles_matrix(V);
146123
}

src/03_lm_fit.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
const list &k_list) {
77
// Type conversion
88

9-
vec y = as_Col(y_r);
109
mat X = as_Mat(x_r);
11-
vec MNU = vec(y.n_elem, fill::zeros);
12-
vec w = as_Col(wt_r);
10+
const vec y = as_Col(y_r);
11+
const vec w = as_Col(wt_r);
1312

1413
// Auxiliary variables (fixed)
1514

@@ -20,23 +19,35 @@
2019

2120
// Auxiliary variables (storage)
2221

23-
mat MX, H;
22+
mat H(X.n_cols, X.n_cols);
23+
vec MNU(y.n_elem);
2424

2525
// Center variables
2626

27-
MNU += y;
28-
center_variables_(MNU, w, k_list, center_tol, iter_center_max, iter_interrupt,
29-
iter_ssr);
30-
center_variables_(X, w, k_list, center_tol, iter_center_max, iter_interrupt,
31-
iter_ssr);
27+
if (k_list.size() > 0) {
28+
// Initial response + centering for fixed effects
29+
MNU = y;
30+
center_variables_(MNU, w, k_list, center_tol, iter_center_max,
31+
iter_interrupt, iter_ssr);
32+
center_variables_(X, w, k_list, center_tol, iter_center_max, iter_interrupt,
33+
iter_ssr);
34+
} else {
35+
// No fixed effects
36+
MNU = vec(y.n_elem, fill::zeros);
37+
}
3238

3339
// Solve the normal equations
3440

3541
vec beta = solve_beta_(X, MNU, w);
3642

3743
// Fitted values
3844

39-
vec fitted = y - MNU + X * beta;
45+
vec fitted;
46+
if (k_list.size() > 0) {
47+
fitted = y - MNU + X * beta;
48+
} else {
49+
fitted = X * beta;
50+
}
4051

4152
// Recompute Hessian
4253

src/04_glm_fit.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,12 @@ vec variance_(const vec &mu, const double &theta,
248248
const std::string &family,
249249
const list &control, const list &k_list) {
250250
// Type conversion
251+
mat MX = as_Mat(x_r);
251252
vec beta = as_Col(beta_r);
252253
vec eta = as_Col(eta_r);
253-
vec y = as_Col(y_r);
254-
mat MX = as_Mat(x_r);
254+
const vec y = as_Col(y_r);
255255
vec MNU = vec(y.n_elem, fill::zeros);
256-
vec wt = as_Col(wt_r);
256+
const vec wt = as_Col(wt_r);
257257

258258
// Auxiliary variables (fixed)
259259

@@ -362,11 +362,6 @@ vec variance_(const vec &mu, const double &theta,
362362
stop("Algorithm did not converge.");
363363
}
364364

365-
// Update weights and dependent variable
366-
367-
mu_eta = mu_eta_(eta, family_type);
368-
w = (wt % square(mu_eta)) / variance_(mu, theta, family_type);
369-
370365
// Compute Hessian
371366

372367
H = crossprod_(MX, w);

src/05_glm_offset_fit.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ feglm_offset_fit_(const doubles &eta_r, const doubles &y_r,
99

1010
vec eta = as_Col(eta_r);
1111
vec y = as_Col(y_r);
12-
vec offset = as_Col(offset_r);
12+
const vec offset = as_Col(offset_r);
1313
vec Myadj = vec(y.n_elem, fill::zeros);
14-
vec wt = as_Col(wt_r);
14+
const vec wt = as_Col(wt_r);
1515

1616
// Auxiliary variables (fixed)
1717

@@ -31,7 +31,7 @@ feglm_offset_fit_(const doubles &eta_r, const doubles &y_r,
3131
vec mu = link_inv_(eta, family_type);
3232
double dev = dev_resids_(y, mu, 0.0, wt, family_type);
3333

34-
const int n = y.n_elem;
34+
const size_t n = y.n_elem;
3535
vec mu_eta(n), yadj(n), w(n);
3636

3737
bool dev_crit, val_crit, imp_crit;

0 commit comments

Comments
 (0)