Skip to content

Commit a55c32d

Browse files
committed
ENH: nsa flow
1 parent c74c075 commit a55c32d

File tree

1 file changed

+111
-59
lines changed

1 file changed

+111
-59
lines changed

vignettes/nsa_flow.Rmd

Lines changed: 111 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,14 @@ Y0_toy <- matrix(runif(12, 0, 1), 4, 3)
327327
328328
# Apply NSA-Flow with balanced weights
329329
# X0_toy=X0_toy/norm(X0_toy, "F")*0.1 # Normalize
330-
omega_default = 0.05
330+
omega_default = 0.5
331331
# if ( ! exists("ini_default") )
332332
lropts=c('armijo', 'armijo_aggressive', 'exponential', 'linear', 'random', 'adaptive', 'momentum_boost', 'entropy', 'poly_decay', 'bayes')
333-
ini_default = 'armijo' #
334-
optype='lars' # for torch backend
333+
ini_default = 'bayes' #
334+
optype='asgd' # for torch backend
335335
def_ret = "soft_polar"
336-
nsa_default <- function(Y0, w = omega_default,
337-
X0 = NULL, init=ini_default, agg=0.8, verbose = FALSE ) {
336+
nsa_default <- function(Y0, w = omega_default,
337+
X0 = NULL, init=ini_default, agg=0.5, verbose = FALSE ) {
338338
nsa_flow_autograd(
339339
Y0 = Y0,
340340
X0 = X0,
@@ -346,7 +346,7 @@ nsa_default <- function(Y0, w = omega_default,
346346
apply_nonneg = TRUE,
347347
tol = 1e-6,
348348
window_size=10,
349-
fidelity_type = "symmetric",
349+
fidelity_type = "scale_invariant", #"symmetric",
350350
orth_type = "scale_invariant",
351351
lr_strategy = init,
352352
aggression = agg,
@@ -355,26 +355,7 @@ nsa_default <- function(Y0, w = omega_default,
355355
)
356356
}
357357
#
358-
# res_toy <- nsa_default(Y0 = X0_toy, X0 = true_Y, w = omega_default )
359-
#
360-
res_toy = nsa_flow_autograd(
361-
Y0 = X0_toy,
362-
X0 = true_Y,
363-
w = omega_default,
364-
retraction = def_ret,
365-
max_iter = 500,
366-
verbose = FALSE,
367-
seed = 42,
368-
apply_nonneg = TRUE,
369-
tol = 1e-6,
370-
window_size=5,
371-
fidelity_type = "scale_invariant",
372-
orth_type = "scale_invariant",
373-
lr_strategy = 'armijo',
374-
optimizer = optype,
375-
plot = TRUE
376-
)
377-
358+
res_toy <- nsa_default(Y0 = X0_toy, X0 = true_Y, w = omega_default )
378359
# Visualize
379360
380361
library(ggplot2)
@@ -592,13 +573,13 @@ X0 = generate_synth_data( p, k, corrval=0.35, noise=0.05, sparse_prob=0.0, inclu
592573
###
593574
w_seq <- c( 0.005, 0.05, 0.1, 0.2, 0.5 )
594575
w_seq <- c( 0.001, 0.005, 0.01, 0.05, 0.25 )
595-
w_seq <- c( 0.1, 0.25, 0.5, 0.75, 0.9 )
576+
w_seq <- c( 0.001, 0.25, 0.5, 0.75, 0.9 )
596577
mytit = paste0("w = ", round(w_seq,3))
597578
mats <- list()
598579
convergeplots <- list()
599580
for(i in seq_along(w_seq)) {
600581
w_val <- w_seq[i]
601-
res_soft_w <- nsa_default( X0, w = w_val, verbose = FALSE )
582+
res_soft_w <- nsa_default( X0, w = w_val, agg=0.95, verbose = FALSE )
602583
mytit[i] <- paste0("w = ", round(w_val, 3), ', orth = ',
603584
round(invariant_orthogonality_defect(res_soft_w$Y),4), ', w.spar = ',
604585
1.0-round(sum(res_soft_w$Y/max(res_soft_w$Y) > quantile(res_soft_w$Y,0.1))/length(res_soft_w$Y),3))
@@ -622,7 +603,8 @@ for(i in seq_along(mats)) {
622603
}
623604
grid.arrange(grobs = lapply(swplots, function(x) x$gtable), ncol = 3)
624605
if ( length(convergeplots) >=4 ) {
625-
grid.arrange(grobs=convergeplots[c(1,2,3,5)], top='Convergence Plots for Different w Values', ncol=2 )
606+
grid.arrange(grobs=convergeplots[c(1,2,3,5)],
607+
top='Convergence Plots for Different w Values', ncol=2 )
626608
}
627609
628610
# darkk #
@@ -932,8 +914,11 @@ explained_variance_ratio_by_orthonormalizing <- function(X = NULL, Y, use = c("q
932914
X=generate_synth_data( p=100, k=20, corrval=0.35)$Y0
933915
nembed = 4
934916
# --- Compute results for both methods ---
935-
res_soft <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200, nsa_w = 0.5, tol = 1e-5,proximal_type='basic', verbose = FALSE )
936-
res_nns <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200, nsa_w = 0.5, tol = 1e-5, proximal_type='nsa_flow', nsa_flow_fn = nsa_default, verbose = FALSE )
917+
res_soft <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200,
918+
nsa_w = omega_default, tol = 1e-5, proximal_type='basic', verbose = FALSE )
919+
res_nns <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200,
920+
nsa_w = omega_default, tol = 1e-5, proximal_type='nsa_flow',
921+
nsa_flow_fn = nsa_default, verbose = FALSE )
937922
938923
939924
@@ -1073,13 +1058,12 @@ golub_scaled_ss <- golub_scaled[, ss]
10731058
pca_std <- prcomp(golub_scaled_ss, rank. = myk)
10741059
proj_std <- pca_std$x
10751060
1076-
res_basic <- nsa_flow_pca(golub_scaled_ss, myk,lambda = 0.1, alpha = 0.001,
1061+
res_basic <- nsa_flow_pca(golub_scaled_ss, myk,lambda = 0.1, alpha = 0.01,
10771062
max_iter = mxit, proximal_type = "basic", tol = 1e-5,
10781063
nsa_w = omega_default, verbose = F)
1079-
res_nns <- nsa_flow_pca(golub_scaled_ss, myk, lambda = 0.1, alpha = 0.001,
1064+
res_nns <- nsa_flow_pca(golub_scaled_ss, myk, lambda = 0.1, alpha = 0.01,
10801065
max_iter = mxit, proximal_type = "nsa_flow", tol = 1e-5,
10811066
nsa_w = omega_default, nsa_flow_fn = nsa_default, verbose = FALSE)
1082-
10831067
## --- Core Metrics ------------------------------------------------------------
10841068
metrics_pca_g <- compute_core_metrics(pca_std$rotation, golub_scaled_ss)
10851069
metrics_basic_g <- compute_core_metrics(res_basic$Y, golub_scaled_ss)
@@ -1313,7 +1297,7 @@ regions <- c(regions, paste0('left_', subcortical), paste0('right_', subcortical
13131297
stopifnot(length(regions) == p)
13141298
X <- as.data.frame(X_data)
13151299
colnames(X) <- regions
1316-
1300+
#############################
13171301
13181302
library(readr)
13191303
adnifn="../../multidisorder/data/ppmiadni_filtered.csv"
@@ -1647,7 +1631,8 @@ for (net in 1:netrank) {
16471631

16481632

16491633

1650-
```{r enhanced-stats-dx,echo=FALSE,fig.width=7,fig.cap="Population-level distribution of NSA-based embeddings for different diagnostic groups. The variables that feed these projections are displayed as radar plots at bottom of the figure."}
1634+
```{r enhanced-stats-dx-0,echo=FALSE,fig.width=7,fig.cap="Multi-class random forest classification results: NSA vs PCA. Two runs of 4-fold cross-validation with results measured by AUC."}
1635+
library( randomForest )
16511636
ppmiadni$diagnosis = as.character(ppmiadni$diagnosis)
16521637
ppmiadni$diagnosis[ppmiadni$diagnosis %in% "SMC"]=NA
16531638
ppmiadni$diagnosis = factor( ppmiadni$diagnosis, levels=c("CN","MCI","AD") )
@@ -1666,34 +1651,26 @@ if (!all(c("diagnosis", covars) %in% names(ppmiadni))) {
16661651
stop("ppmiadni must contain 'diagnosis' and all variables in covars.")
16671652
}
16681653
1654+
set.seed(1) # For reproducibility
16691655
# Initialize data frame for logistic results and ROC lists
16701656
log_results_df <- data.frame()
16711657
roc_nsa_list <- list()
16721658
roc_pca_list <- list()
1673-
1674-
1659+
mywws = seq(0.0, 0.5, by = 0.05 )
16751660
# Loop over weight values
1676-
for (ww in seq(0.1, 0.9, by = 0.2 )) {
1661+
for ( run in 1:2 )
1662+
for (ww in mywws ) {
16771663
# Run NSA (assumes nsa_default is a custom function)
1678-
M_nsa <- tryCatch(
1679-
nsa_default(Y0 = Y0_pca, w = ww),
1680-
error = function(e) {
1681-
cat("Error in nsa_default for weight", ww, ":", e$message, "\n")
1682-
return(NULL)
1683-
}
1684-
)
1685-
if (is.null(M_nsa)) {
1686-
cat("Skipping weight", ww, "due to NSA failure.\n")
1687-
next
1688-
}
1664+
M_nsa <- nsa_default(Y0 = Y0_pca, w = ww)
16891665
Ymat <- M_nsa$Y
16901666
proj_nsa <- as.matrix(X) %*% Ymat
16911667
proj_pca <- as.matrix(X) %*% Y0_pca
16921668
colnames(proj_nsa) <- paste0("nsa", 1:ncol(proj_nsa))
16931669
colnames(proj_pca) <- paste0("pca", 1:ncol(proj_pca))
16941670
16951671
# Combine data with covariates and projections
1696-
temp <- cbind(ppmiadni[, c("diagnosis", covars)], proj_nsa, proj_pca)
1672+
temp <- cbind(ppmiadni[, c("diagnosis", "AGE", "PTGENDER")], proj_nsa, proj_pca)
1673+
temp$AGE = antsrimpute( temp$AGE )
16971674
dx2 <- "AD"
16981675
dx1 <- "MCI"
16991676
dx0 <- "CN"
@@ -1704,7 +1681,6 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
17041681
17051682
# 4-fold cross-validation
17061683
nfold <- 4
1707-
set.seed(123) # For reproducibility
17081684
folds <- createFolds(temp$diagnosis, k = nfold, list = TRUE, returnTrain = FALSE)
17091685
auc_nsa <- auc_pca <- auc_random <- numeric(nfold)
17101686
@@ -1725,10 +1701,12 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
17251701
auc_pca[f] <- NA
17261702
next
17271703
}
1728-
1704+
mlfun = multinom; ptype="probs"
1705+
mlfun = randomForest; ptype='prob'
1706+
# mlfun = imbalanced
17291707
# Base model (AGE and PTGENDER)
17301708
base_mod <- tryCatch(
1731-
multinom(as.formula("diagnosis ~ AGE + PTGENDER"), data = train, trace = FALSE),
1709+
mlfun(as.formula("diagnosis ~ AGE + PTGENDER"), data = train, trace = FALSE),
17321710
error = function(e) {
17331711
cat("Error in base model for fold", f, "weight", ww, ":", e$message, "\n")
17341712
return(NULL)
@@ -1739,14 +1717,14 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
17391717
# NSA full model
17401718
full_nsa_form <- paste("diagnosis ~", covarsbin, "+", paste(colnames(proj_nsa), collapse = "+"))
17411719
full_nsa <- tryCatch(
1742-
multinom(as.formula(full_nsa_form), data = train, trace = FALSE),
1720+
mlfun(as.formula(full_nsa_form), data = train, trace = FALSE),
17431721
error = function(e) {
17441722
cat("Error in NSA model for fold", f, "weight", ww, ":", e$message, "\n")
17451723
return(NULL)
17461724
}
17471725
)
17481726
if (!is.null(full_nsa)) {
1749-
preds_nsa <- predict(full_nsa, test, type = "probs")
1727+
preds_nsa <- predict(full_nsa, test, type = ptype )
17501728
roc_nsa <- tryCatch(
17511729
multiclass.roc(test$diagnosis, preds_nsa, levels = levels(test$diagnosis), quiet = TRUE),
17521730
error = function(e) {
@@ -1762,14 +1740,14 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
17621740
# PCA full model
17631741
full_pca_form <- paste("diagnosis ~", covarsbin, "+", paste(colnames(proj_pca), collapse = "+"))
17641742
full_pca <- tryCatch(
1765-
multinom(as.formula(full_pca_form), data = train, trace = FALSE),
1743+
mlfun(as.formula(full_pca_form), data = train, trace = FALSE),
17661744
error = function(e) {
17671745
cat("Error in PCA model for fold", f, "weight", ww, ":", e$message, "\n")
17681746
return(NULL)
17691747
}
17701748
)
17711749
if (!is.null(full_pca)) {
1772-
preds_pca <- predict(full_pca, test, type = "probs")
1750+
preds_pca <- predict(full_pca, test, type = ptype )
17731751
roc_pca <- tryCatch(
17741752
multiclass.roc(test$diagnosis, preds_pca, levels = levels(test$diagnosis), quiet = TRUE),
17751753
error = function(e) {
@@ -1800,6 +1778,81 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
18001778
))
18011779
}
18021780
1781+
1782+
# Load required libraries
library(ggplot2)
library(dplyr)
library(tidyr)
library(gridExtra)
library(scales)

# Add experiment label (two runs).
# The labelling below presumes the cross-validation loop appended exactly one
# row per (run, w) pair, with all of run 1 before run 2 -- TODO confirm against
# the loop body. Fail with an explicit message instead of a generic
# replacement-length error if that does not hold.
if (nrow(log_results_df) != 2 * length(mywws)) {
  stop("log_results_df has ", nrow(log_results_df), " rows but ",
       2 * length(mywws), " were expected (2 runs x ", length(mywws),
       " weights); cannot assign experiment labels.", call. = FALSE)
}
log_results_df$experiment <- rep(c("Run 1", "Run 2"), each = length(mywws))

# === Long format for AUC lines ===
# One row per (w, experiment, method) so both methods share a colour scale.
auc_long <- log_results_df %>%
  dplyr::select(w, experiment, auc_nsa, auc_pca) %>%
  pivot_longer(cols = c(auc_nsa, auc_pca), names_to = "method", values_to = "auc") %>%
  mutate(method = ifelse(method == "auc_nsa", "NSA-ICA", "PCA"))

# === Random-chance reference level ===
# The previous code read `data$random_accuracy` unconditionally. Nothing in this
# part of the chunk creates `data`; when no such object exists the name resolves
# to utils::data() and `$` fails with "object of type 'closure' is not
# subsettable". Keep using `data` when it really is a list/data frame carrying
# the column, otherwise fall back to `log_results_df`, and omit the reference
# line when no numeric value is available.
chance_values <- if (is.list(data) && !is.null(data[["random_accuracy"]])) {
  data[["random_accuracy"]]
} else {
  log_results_df[["random_accuracy"]] # NULL when the column is absent
}
chance_level <- if (is.numeric(chance_values) && length(chance_values) > 0) {
  mean(chance_values, na.rm = TRUE)
} else {
  NA_real_
}

# === Plot 1: AUC vs w (NSA-ICA vs PCA) ===
# NOTE(review): the y limits are fixed to [0.68, 0.79]; ggplot2 drops anything
# outside scale limits, so AUC values (or a chance level) outside that window
# will not be drawn. Confirm the window matches the actual results.
p1 <- ggplot(auc_long, aes(x = w, y = auc, color = method, linetype = experiment)) +
  geom_line(size = 1.2) +
  geom_point(size = 2.5, shape = 21, fill = "white", stroke = 1.2) +
  scale_color_manual(values = c("NSA-ICA" = "#D55E00", "PCA" = "#0072B2")) +
  scale_linetype_manual(values = c("Run 1" = "solid", "Run 2" = "dashed")) +
  labs(
    title = "Cross-Validated AUC for Diagnosis Prediction",
    subtitle = "NSA-ICA vs PCA across regularization strength (w)",
    x = "NSA Regularization Weight (w)",
    y = "Mean AUC (Cross-Validation)",
    color = "Method",
    linetype = "Experiment"
  ) +
  theme_minimal(base_size = 8) +
  theme(
    plot.title = element_text(face = "bold", size = 8),
    plot.subtitle = element_text(color = "gray50"),
    legend.position = "top",
    panel.grid.minor = element_blank(),
    axis.line = element_line(color = "gray70"),
    axis.ticks = element_line(color = "gray70")
  ) +
  scale_y_continuous(limits = c(0.68, 0.79), breaks = seq(0.68, 0.79, 0.02))

# Add the chance reference only when a value exists; the label reports the
# computed level rather than a hard-coded number.
if (is.finite(chance_level)) {
  p1 <- p1 +
    geom_hline(yintercept = chance_level, linetype = "dotted", color = "gray40", size = 1) +
    annotate("text", x = 0.45, y = chance_level + 0.005,
             label = sprintf("Random Chance (~%.3f)", chance_level),
             color = "gray40", size = 3.5, hjust = 0)
}

# === Plot 2: AUC Difference (NSA - PCA) ===
# The translucent green band marks the region where NSA outperforms PCA.
p2 <- ggplot(log_results_df, aes(x = w, y = auc_diff, color = experiment)) +
  geom_hline(yintercept = 0, linetype = "solid", color = "gray50") +
  geom_line(size = 1.3) +
  geom_point(size = 3, shape = 18) +
  scale_color_manual(values = c("Run 1" = "#D55E00", "Run 2" = "#CC79A7")) +
  labs(
    title = "Performance Gain: NSA-ICA − PCA",
    x = "NSA Regularization Weight (w)",
    y = "ΔAUC (NSA-ICA − PCA)",
    color = "Experiment"
  ) +
  theme_minimal(base_size = 8) +
  theme(
    plot.title = element_text(face = "bold", size = 8),
    legend.position = "none",
    panel.grid.minor = element_blank(),
    axis.line = element_line(color = "gray70"),
    axis.ticks = element_line(color = "gray70")
  ) +
  scale_y_continuous(labels = percent_format(accuracy = 0.1)) +
  annotate("rect", xmin = -Inf, xmax = Inf, ymin = 0, ymax = Inf, fill = "green", alpha = 0.05)

# === Combine Plots ===
# grid.arrange() draws the figure itself and returns the gtable invisibly.
# Printing that gtable would only emit its text layout summary into the
# rendered document, so it is not printed here.
final_plot <- gridExtra::grid.arrange(p1, p2, ncol = 1, heights = c(1.8, 1))
1851+
```
1852+
1853+
1854+
```{r enhanced-stats-dx,echo=FALSE,fig.width=7,fig.cap="Population-level distribution of NSA-based embeddings for different diagnostic groups. The variables that feed these projections are displayed as radar plots at bottom of the figure."}
1855+
18031856
# Summary for logistic regression
18041857
log_summary <- log_results_df %>%
18051858
summarise(
@@ -1871,7 +1924,6 @@ for (comp in psel ) {
18711924
# grid.arrange( grobs = gglist, ncol = 2 , top="Distributions of Significant NSA Components by Diagnosis" )
18721925
18731926
# gt::gt(log_results_df,caption='Cross-Validated AUC Summary for Diagnosis Prediction: Full results.')
1874-
18751927
gt::gt(log_summary,caption='Cross-Validated AUC Summary for Diagnosis Prediction: Statistical Summary.')
18761928
18771929
# Plot results

0 commit comments

Comments
 (0)