@@ -327,14 +327,14 @@ Y0_toy <- matrix(runif(12, 0, 1), 4, 3)
327327
328328# Apply NSA-Flow with balanced weights
329329# X0_toy=X0_toy/norm(X0_toy, "F")*0.1 # Normalize
330- omega_default = 0.05
330+ omega_default = 0.5
331331# if ( ! exists("ini_default") )
332332lropts=c('armijo', 'armijo_aggressive', 'exponential', 'linear', 'random', 'adaptive', 'momentum_boost', 'entropy', 'poly_decay', 'bayes')
333- ini_default = 'armijo ' #
334- optype='lars ' # for torch backend
333+ ini_default = 'bayes ' #
334+ optype='asgd ' # for torch backend
335335def_ret = "soft_polar"
336- nsa_default <- function(Y0, w = omega_default,
337- X0 = NULL, init=ini_default, agg=0.8 , verbose = FALSE ) {
336+ nsa_default <- function(Y0, w = omega_default,
337+ X0 = NULL, init=ini_default, agg=0.5 , verbose = FALSE ) {
338338 nsa_flow_autograd(
339339 Y0 = Y0,
340340 X0 = X0,
@@ -346,7 +346,7 @@ nsa_default <- function(Y0, w = omega_default,
346346 apply_nonneg = TRUE,
347347 tol = 1e-6,
348348 window_size=10,
349- fidelity_type = "symmetric",
349+ fidelity_type = "scale_invariant", #" symmetric",
350350 orth_type = "scale_invariant",
351351 lr_strategy = init,
352352 aggression = agg,
@@ -355,26 +355,7 @@ nsa_default <- function(Y0, w = omega_default,
355355 )
356356}
357357#
358- # res_toy <- nsa_default(Y0 = X0_toy, X0 = true_Y, w = omega_default )
359- #
360- res_toy = nsa_flow_autograd(
361- Y0 = X0_toy,
362- X0 = true_Y,
363- w = omega_default,
364- retraction = def_ret,
365- max_iter = 500,
366- verbose = FALSE,
367- seed = 42,
368- apply_nonneg = TRUE,
369- tol = 1e-6,
370- window_size=5,
371- fidelity_type = "scale_invariant",
372- orth_type = "scale_invariant",
373- lr_strategy = 'armijo',
374- optimizer = optype,
375- plot = TRUE
376- )
377-
358+ res_toy <- nsa_default(Y0 = X0_toy, X0 = true_Y, w = omega_default )
378359# Visualize
379360
380361library(ggplot2)
@@ -592,13 +573,13 @@ X0 = generate_synth_data( p, k, corrval=0.35, noise=0.05, sparse_prob=0.0, inclu
592573###
593574w_seq <- c( 0.005, 0.05, 0.1, 0.2, 0.5 )
594575w_seq <- c( 0.001, 0.005, 0.01, 0.05, 0.25 )
595- w_seq <- c( 0.1 , 0.25, 0.5, 0.75, 0.9 )
576+ w_seq <- c( 0.001 , 0.25, 0.5, 0.75, 0.9 )
596577mytit = paste0("w = ", round(w_seq,3))
597578mats <- list()
598579convergeplots <- list()
599580for(i in seq_along(w_seq)) {
600581 w_val <- w_seq[i]
601- res_soft_w <- nsa_default( X0, w = w_val, verbose = FALSE )
582+ res_soft_w <- nsa_default( X0, w = w_val, agg=0.95, verbose = FALSE )
602583 mytit[i] <- paste0("w = ", round(w_val, 3), ', orth = ',
603584 round(invariant_orthogonality_defect(res_soft_w$Y),4), ', w.spar = ',
604585 1.0-round(sum(res_soft_w$Y/max(res_soft_w$Y) > quantile(res_soft_w$Y,0.1))/length(res_soft_w$Y),3))
@@ -622,7 +603,8 @@ for(i in seq_along(mats)) {
622603}
623604grid.arrange(grobs = lapply(swplots, function(x) x$gtable), ncol = 3)
# Draw the convergence curves for four of the w values in a 2 x 2 grid.
# The selection below uses list element 5, so at least five plots must be
# present; guarding on a smaller count would hand grid.arrange() a NULL entry.
if (length(convergeplots) >= 5) {
  grid.arrange(
    grobs = convergeplots[c(1, 2, 3, 5)],
    top = 'Convergence Plots for Different w Values',
    ncol = 2
  )
}
627609
628610# darkk #
@@ -932,8 +914,11 @@ explained_variance_ratio_by_orthonormalizing <- function(X = NULL, Y, use = c("q
932914X=generate_synth_data( p=100, k=20, corrval=0.35)$Y0
933915nembed = 4
934916# --- Compute results for both methods ---
935- res_soft <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200, nsa_w = 0.5, tol = 1e-5,proximal_type='basic', verbose = FALSE )
936- res_nns <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200, nsa_w = 0.5, tol = 1e-5, proximal_type='nsa_flow', nsa_flow_fn = nsa_default, verbose = FALSE )
917+ res_soft <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200,
918+ nsa_w = omega_default, tol = 1e-5, proximal_type='basic', verbose = FALSE )
919+ res_nns <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200,
920+ nsa_w = omega_default, tol = 1e-5, proximal_type='nsa_flow',
921+ nsa_flow_fn = nsa_default, verbose = FALSE )
937922
938923
939924
@@ -1073,13 +1058,12 @@ golub_scaled_ss <- golub_scaled[, ss]
10731058pca_std <- prcomp(golub_scaled_ss, rank. = myk)
10741059proj_std <- pca_std$x
10751060
1076- res_basic <- nsa_flow_pca(golub_scaled_ss, myk,lambda = 0.1, alpha = 0.001 ,
1061+ res_basic <- nsa_flow_pca(golub_scaled_ss, myk,lambda = 0.1, alpha = 0.01 ,
10771062 max_iter = mxit, proximal_type = "basic", tol = 1e-5,
10781063 nsa_w = omega_default, verbose = F)
1079- res_nns <- nsa_flow_pca(golub_scaled_ss, myk, lambda = 0.1, alpha = 0.001 ,
1064+ res_nns <- nsa_flow_pca(golub_scaled_ss, myk, lambda = 0.1, alpha = 0.01 ,
10801065 max_iter = mxit, proximal_type = "nsa_flow", tol = 1e-5,
10811066 nsa_w = omega_default, nsa_flow_fn = nsa_default, verbose = FALSE)
1082-
10831067## --- Core Metrics ------------------------------------------------------------
10841068metrics_pca_g <- compute_core_metrics(pca_std$rotation, golub_scaled_ss)
10851069metrics_basic_g <- compute_core_metrics(res_basic$Y, golub_scaled_ss)
@@ -1313,7 +1297,7 @@ regions <- c(regions, paste0('left_', subcortical), paste0('right_', subcortical
13131297stopifnot(length(regions) == p)
13141298X <- as.data.frame(X_data)
13151299colnames(X) <- regions
1316-
1300+ #############################
13171301
13181302library(readr)
13191303adnifn="../../multidisorder/data/ppmiadni_filtered.csv"
@@ -1647,7 +1631,8 @@ for (net in 1:netrank) {
16471631
16481632
16491633
1650- ``` {r enhanced-stats-dx,echo=FALSE,fig.width=7,fig.cap="Population-level distribution of NSA-based embeddings for different diagnostic groups. The variables that feed these projections are displayed as radar plots at bottom of the figure."}
1634+ ``` {r enhanced-stats-dx-0,echo=FALSE,fig.width=7,fig.cap="Multi-class random forest classification results: NSA vs PCA. Two runs of 4-fold cross-validation with results measured by AUC."}
1635+ library( randomForest )
16511636ppmiadni$diagnosis = as.character(ppmiadni$diagnosis)
16521637ppmiadni$diagnosis[ppmiadni$diagnosis %in% "SMC"]=NA
16531638ppmiadni$diagnosis = factor( ppmiadni$diagnosis, levels=c("CN","MCI","AD") )
@@ -1666,34 +1651,26 @@ if (!all(c("diagnosis", covars) %in% names(ppmiadni))) {
16661651 stop("ppmiadni must contain 'diagnosis' and all variables in covars.")
16671652}
16681653
1654+ set.seed(1) # For reproducibility
16691655# Initialize data frame for logistic results and ROC lists
16701656log_results_df <- data.frame()
16711657roc_nsa_list <- list()
16721658roc_pca_list <- list()
1673-
1674-
1659+ mywws = seq(0.0, 0.5, by = 0.05 )
16751660# Loop over weight values
1676- for (ww in seq(0.1, 0.9, by = 0.2 )) {
1661+ for ( run in 1:2 )
1662+ for (ww in mywws ) {
16771663 # Run NSA (assumes nsa_default is a custom function)
1678- M_nsa <- tryCatch(
1679- nsa_default(Y0 = Y0_pca, w = ww),
1680- error = function(e) {
1681- cat("Error in nsa_default for weight", ww, ":", e$message, "\n")
1682- return(NULL)
1683- }
1684- )
1685- if (is.null(M_nsa)) {
1686- cat("Skipping weight", ww, "due to NSA failure.\n")
1687- next
1688- }
1664+ M_nsa <- nsa_default(Y0 = Y0_pca, w = ww)
16891665 Ymat <- M_nsa$Y
16901666 proj_nsa <- as.matrix(X) %*% Ymat
16911667 proj_pca <- as.matrix(X) %*% Y0_pca
16921668 colnames(proj_nsa) <- paste0("nsa", 1:ncol(proj_nsa))
16931669 colnames(proj_pca) <- paste0("pca", 1:ncol(proj_pca))
16941670
16951671 # Combine data with covariates and projections
1696- temp <- cbind(ppmiadni[, c("diagnosis", covars)], proj_nsa, proj_pca)
1672+ temp <- cbind(ppmiadni[, c("diagnosis", "AGE", "PTGENDER")], proj_nsa, proj_pca)
1673+ temp$AGE = antsrimpute( temp$AGE )
16971674 dx2 <- "AD"
16981675 dx1 <- "MCI"
16991676 dx0 <- "CN"
@@ -1704,7 +1681,6 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
17041681
17051682 # 4-fold cross-validation
17061683 nfold <- 4
1707- set.seed(123) # For reproducibility
17081684 folds <- createFolds(temp$diagnosis, k = nfold, list = TRUE, returnTrain = FALSE)
17091685 auc_nsa <- auc_pca <- auc_random <- numeric(nfold)
17101686
@@ -1725,10 +1701,12 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
17251701 auc_pca[f] <- NA
17261702 next
17271703 }
1728-
1704+ mlfun = multinom; ptype="probs"
1705+ mlfun = randomForest; ptype='prob'
1706+ # mlfun = imbalanced
17291707 # Base model (AGE and PTGENDER)
17301708 base_mod <- tryCatch(
1731- multinom (as.formula("diagnosis ~ AGE + PTGENDER"), data = train, trace = FALSE),
1709+ mlfun (as.formula("diagnosis ~ AGE + PTGENDER"), data = train, trace = FALSE),
17321710 error = function(e) {
17331711 cat("Error in base model for fold", f, "weight", ww, ":", e$message, "\n")
17341712 return(NULL)
@@ -1739,14 +1717,14 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
17391717 # NSA full model
17401718 full_nsa_form <- paste("diagnosis ~", covarsbin, "+", paste(colnames(proj_nsa), collapse = "+"))
17411719 full_nsa <- tryCatch(
1742- multinom (as.formula(full_nsa_form), data = train, trace = FALSE),
1720+ mlfun (as.formula(full_nsa_form), data = train, trace = FALSE),
17431721 error = function(e) {
17441722 cat("Error in NSA model for fold", f, "weight", ww, ":", e$message, "\n")
17451723 return(NULL)
17461724 }
17471725 )
17481726 if (!is.null(full_nsa)) {
1749- preds_nsa <- predict(full_nsa, test, type = "probs" )
1727+ preds_nsa <- predict(full_nsa, test, type = ptype )
17501728 roc_nsa <- tryCatch(
17511729 multiclass.roc(test$diagnosis, preds_nsa, levels = levels(test$diagnosis), quiet = TRUE),
17521730 error = function(e) {
@@ -1762,14 +1740,14 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
17621740 # PCA full model
17631741 full_pca_form <- paste("diagnosis ~", covarsbin, "+", paste(colnames(proj_pca), collapse = "+"))
17641742 full_pca <- tryCatch(
1765- multinom (as.formula(full_pca_form), data = train, trace = FALSE),
1743+ mlfun (as.formula(full_pca_form), data = train, trace = FALSE),
17661744 error = function(e) {
17671745 cat("Error in PCA model for fold", f, "weight", ww, ":", e$message, "\n")
17681746 return(NULL)
17691747 }
17701748 )
17711749 if (!is.null(full_pca)) {
1772- preds_pca <- predict(full_pca, test, type = "probs" )
1750+ preds_pca <- predict(full_pca, test, type = ptype )
17731751 roc_pca <- tryCatch(
17741752 multiclass.roc(test$diagnosis, preds_pca, levels = levels(test$diagnosis), quiet = TRUE),
17751753 error = function(e) {
@@ -1800,6 +1778,81 @@ for (ww in seq(0.1, 0.9, by = 0.2 )) {
18001778 ))
18011779}
18021780
1781+
# Plot the cross-validated AUC of the NSA and PCA models against the NSA
# weight `w` (one line per run), plus the per-weight AUC difference.
# Reads `log_results_df` and `mywws` created by the code above.
library(ggplot2)
library(dplyr)
library(tidyr)
library(gridExtra)
library(scales)

# Add experiment label (two runs). The weight loop is nested inside the run
# loop, so rows are ordered as all weights for run 1, then all for run 2.
# Check the row count explicitly: data-frame column assignment recycles a
# shorter vector when the lengths divide evenly, which would mislabel rows.
n_expected_rows <- 2L * length(mywws)
if (nrow(log_results_df) != n_expected_rows) {
  stop(
    "log_results_df has ", nrow(log_results_df), " rows; expected ",
    n_expected_rows, " (2 runs x ", length(mywws), " weights).",
    call. = FALSE
  )
}
log_results_df$experiment <- rep(c("Run 1", "Run 2"), each = length(mywws))

# === Long format for AUC lines ===
auc_long <- log_results_df %>%
  dplyr::select(w, experiment, auc_nsa, auc_pca) %>%
  pivot_longer(cols = c(auc_nsa, auc_pca), names_to = "method", values_to = "auc") %>%
  mutate(method = ifelse(method == "auc_nsa", "NSA-ICA", "PCA"))

# NOTE(review): `data` is not defined in the code visible here. Unless a data
# frame named `data` with a `random_accuracy` column exists elsewhere in the
# document, this resolves to the base function `data` and fails. Confirm the
# intended source of the chance level.
chance_level <- mean(data$random_accuracy)

# === Plot 1: AUC vs w (NSA-ICA vs PCA) ===
# NOTE(review): the y scale is limited to [0.68, 0.79]; values outside that
# range (including the chance-level line, if it is near the 0.406 quoted in
# the label) are dropped from the panel by the scale. Confirm this is wanted.
p1 <- ggplot(auc_long, aes(x = w, y = auc, color = method, linetype = experiment)) +
  geom_line(size = 1.2) +
  geom_point(size = 2.5, shape = 21, fill = "white", stroke = 1.2) +
  scale_color_manual(values = c("NSA-ICA" = "#D55E00", "PCA" = "#0072B2")) +
  scale_linetype_manual(values = c("Run 1" = "solid", "Run 2" = "dashed")) +
  geom_hline(yintercept = chance_level, linetype = "dotted", color = "gray40", size = 1) +
  annotate("text", x = 0.45, y = chance_level + 0.005,
           label = "Random Chance (~0.406)", color = "gray40", size = 3.5, hjust = 0) +
  labs(
    title = "Cross-Validated AUC for Diagnosis Prediction",
    subtitle = "NSA-ICA vs PCA across regularization strength (w)",
    x = "NSA Regularization Weight (w)",
    y = "Mean AUC (Cross-Validation)",
    color = "Method",
    linetype = "Experiment"
  ) +
  theme_minimal(base_size = 8) +
  theme(
    plot.title = element_text(face = "bold", size = 8),
    plot.subtitle = element_text(color = "gray50"),
    legend.position = "top",
    panel.grid.minor = element_blank(),
    axis.line = element_line(color = "gray70"),
    axis.ticks = element_line(color = "gray70")
  ) +
  scale_y_continuous(limits = c(0.68, 0.79), breaks = seq(0.68, 0.79, 0.02))

# === Plot 2: AUC Difference (NSA - PCA) ===
# The shaded rectangle marks the region where NSA outperforms PCA (diff > 0).
p2 <- ggplot(log_results_df, aes(x = w, y = auc_diff, color = experiment)) +
  geom_hline(yintercept = 0, linetype = "solid", color = "gray50") +
  geom_line(size = 1.3) +
  geom_point(size = 3, shape = 18) +
  scale_color_manual(values = c("Run 1" = "#D55E00", "Run 2" = "#CC79A7")) +
  labs(
    title = "Performance Gain: NSA-ICA − PCA",
    x = "NSA Regularization Weight (w)",
    y = "ΔAUC (NSA-ICA − PCA)",
    color = "Experiment"
  ) +
  theme_minimal(base_size = 8) +
  theme(
    plot.title = element_text(face = "bold", size = 8),
    legend.position = "none",
    panel.grid.minor = element_blank(),
    axis.line = element_line(color = "gray70"),
    axis.ticks = element_line(color = "gray70")
  ) +
  scale_y_continuous(labels = percent_format(accuracy = 0.1)) +
  annotate("rect", xmin = -Inf, xmax = Inf, ymin = 0, ymax = Inf, fill = "green", alpha = 0.05)

# === Combine Plots ===
# grid.arrange() draws the combined figure itself and returns the gtable
# invisibly; printing that object would only emit its text layout into the
# rendered output, so it is kept for reuse but not printed.
final_plot <- gridExtra::grid.arrange(p1, p2, ncol = 1, heights = c(1.8, 1))
1851+ ```
1852+
1853+
1854+ ``` {r enhanced-stats-dx,echo=FALSE,fig.width=7,fig.cap="Population-level distribution of NSA-based embeddings for different diagnostic groups. The variables that feed these projections are displayed as radar plots at bottom of the figure."}
1855+
18031856# Summary of cross-validated classification results (NSA vs PCA)
18041857log_summary <- log_results_df %>%
18051858 summarise(
@@ -1871,7 +1924,6 @@ for (comp in psel ) {
18711924# grid.arrange( grobs = gglist, ncol = 2 , top="Distributions of Significant NSA Components by Diagnosis" )
18721925
18731926# gt::gt(log_results_df,caption='Cross-Validated AUC Summary for Diagnosis Prediction: Full results.')
1874-
18751927gt::gt(log_summary,caption='Cross-Validated AUC Summary for Diagnosis Prediction: Statistical Summary.')
18761928
18771929# Plot results
0 commit comments