Skip to content

Commit c74c075

Browse files
committed
WIP: toward torch ag version
1 parent 87662a3 commit c74c075

File tree

5 files changed

+48
-21
lines changed

5 files changed

+48
-21
lines changed

R/multiscaleSVDxpts.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11748,7 +11748,7 @@ digraph NSA_Flow_FA {
1174811748
#' @param tol numeric, tolerance for relative parameter change convergence
1174911749
#' @param retraction retraction function or identifier (passed to nsa_flow)
1175011750
#' @param grad_tol numeric, gradient-norm tolerance for convergence
11751-
#' @param R optional, passed-through (not used here)
11751+
#' @param nsa_flow_fn optional, nsa_flow function to use (default: nsa_flow)
1175211752
#' @param verbose logical, print iteration diagnostics
1175311753
#' @param orth_every integer >=1, perform orthogonalization every this many iterations (default 5)
1175411754
#'
@@ -11771,7 +11771,7 @@ nsa_flow_pca <- function(X, k,
1177111771
w_pca = 1.0, nsa_w = 0.5,
1177211772
apply_soft_thresh_in_nns = FALSE,
1177311773
tol = 1e-6, retraction = def_ret,
11774-
grad_tol = 1e-4, R = NULL, verbose = FALSE,
11774+
grad_tol = 1e-4, nsa_flow_fn = nsa_flow_autograd, verbose = FALSE,
1177511775
orth_every = 5) {
1177611776
# --- argument checks ---
1177711777
if (!is.matrix(X) || any(!is.finite(X))) stop("X must be a finite numeric matrix")
@@ -11883,7 +11883,7 @@ nsa_flow_pca <- function(X, k,
1188311883
} else if (proximal_type == "nsa_flow") {
1188411884
# call nsa_flow; we assume it takes arguments (Y0, X0=NULL, w=..., retraction=...)
1188511885
# use X0 = NULL to indicate proximal-only processing of Y_ret
11886-
prox_res <- nsa_flow(Y_ret, X0 = NULL, w = nsa_w, retraction = retraction)
11886+
prox_res <- nsa_flow_fn( Y_ret, nsa_w )
1188711887
if (!is.list(prox_res) || is.null(prox_res$Y)) stop("nsa_flow returned unexpected result")
1188811888
Y_new <- prox_res$Y
1188911889
} else {

R/nsa_flow_torch.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ nsa_flow_torch_ag <- function(
555555
#' @param optimizer character optimizer name (e.g. 'Adam','lars','sgdp')
556556
#' @param initial_learning_rate NULL (auto), numeric, or character strategy string
557557
#' @param lr_strategy character passed to Python if initial_learning_rate is NULL/'auto'
558+
#' @param aggression numeric controls aggressiveness of learning rate adaptation
558559
#' @param fidelity_type character ('basic','scale_invariant','symmetric','normalized')
559560
#' @param orth_type character ('basic','scale_invariant')
560561
#' @param record_every integer frequency of recording traces
@@ -576,6 +577,7 @@ nsa_flow_autograd <- function(
576577
optimizer = "Adam",
577578
initial_learning_rate = NULL,
578579
lr_strategy = "auto",
580+
aggression = 0.5,
579581
fidelity_type = "scale_invariant",
580582
orth_type = "scale_invariant",
581583
record_every = 1L,

man/nsa_flow_autograd.Rd

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nsa_flow_pca.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vignettes/nsa_flow.Rmd

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -327,12 +327,14 @@ Y0_toy <- matrix(runif(12, 0, 1), 4, 3)
327327
328328
# Apply NSA-Flow with balanced weights
329329
# X0_toy=X0_toy/norm(X0_toy, "F")*0.1 # Normalize
330-
omega_default = 0.01
330+
omega_default = 0.05
331331
# if ( ! exists("ini_default") )
332-
ini_default = 'armijo_aggressive' #
333-
optype='asgd' # for torch backend
332+
lropts=c('armijo', 'armijo_aggressive', 'exponential', 'linear', 'random', 'adaptive', 'momentum_boost', 'entropy', 'poly_decay', 'bayes')
333+
ini_default = 'armijo' #
334+
optype='lars' # for torch backend
334335
def_ret = "soft_polar"
335-
nsa_default <- function(Y0, w = omega_default, X0 = NULL, o=optype, init=ini_default,verbose = FALSE ) {
336+
nsa_default <- function(Y0, w = omega_default,
337+
X0 = NULL, init=ini_default, agg=0.8, verbose = FALSE ) {
336338
nsa_flow_autograd(
337339
Y0 = Y0,
338340
X0 = X0,
@@ -343,17 +345,35 @@ nsa_default <- function(Y0, w = omega_default, X0 = NULL, o=optype, init=ini_def
343345
seed = 42,
344346
apply_nonneg = TRUE,
345347
tol = 1e-6,
346-
window_size=5,
347-
fidelity_type = "scale_invariant",
348+
window_size=10,
349+
fidelity_type = "symmetric",
348350
orth_type = "scale_invariant",
349351
lr_strategy = init,
350-
optimizer = o,
352+
aggression = agg,
353+
optimizer = optype,
351354
plot = TRUE
352355
)
353356
}
354-
355-
356-
res_toy <- nsa_default(Y0 = X0_toy, X0 = true_Y, w = omega_default )
357+
#
358+
# res_toy <- nsa_default(Y0 = X0_toy, X0 = true_Y, w = omega_default )
359+
#
360+
res_toy = nsa_flow_autograd(
361+
Y0 = X0_toy,
362+
X0 = true_Y,
363+
w = omega_default,
364+
retraction = def_ret,
365+
max_iter = 500,
366+
verbose = FALSE,
367+
seed = 42,
368+
apply_nonneg = TRUE,
369+
tol = 1e-6,
370+
window_size=5,
371+
fidelity_type = "scale_invariant",
372+
orth_type = "scale_invariant",
373+
lr_strategy = 'armijo',
374+
optimizer = optype,
375+
plot = TRUE
376+
)
357377
358378
# Visualize
359379
@@ -572,12 +592,13 @@ X0 = generate_synth_data( p, k, corrval=0.35, noise=0.05, sparse_prob=0.0, inclu
572592
###
573593
w_seq <- c( 0.005, 0.05, 0.1, 0.2, 0.5 )
574594
w_seq <- c( 0.001, 0.005, 0.01, 0.05, 0.25 )
595+
w_seq <- c( 0.1, 0.25, 0.5, 0.75, 0.9 )
575596
mytit = paste0("w = ", round(w_seq,3))
576597
mats <- list()
577598
convergeplots <- list()
578599
for(i in seq_along(w_seq)) {
579600
w_val <- w_seq[i]
580-
res_soft_w <- nsa_default( X0, w = w_val, o=optype, verbose = FALSE )
601+
res_soft_w <- nsa_default( X0, w = w_val, verbose = FALSE )
581602
mytit[i] <- paste0("w = ", round(w_val, 3), ', orth = ',
582603
round(invariant_orthogonality_defect(res_soft_w$Y),4), ', w.spar = ',
583604
1.0-round(sum(res_soft_w$Y/max(res_soft_w$Y) > quantile(res_soft_w$Y,0.1))/length(res_soft_w$Y),3))
@@ -604,7 +625,7 @@ if ( length(convergeplots) >=4 ) {
604625
grid.arrange(grobs=convergeplots[c(1,2,3,5)], top='Convergence Plots for Different w Values', ncol=2 )
605626
}
606627
607-
#----------# darkk #
628+
# darkk #
608629
####################
609630
```
610631

@@ -912,7 +933,7 @@ X=generate_synth_data( p=100, k=20, corrval=0.35)$Y0
912933
nembed = 4
913934
# --- Compute results for both methods ---
914935
res_soft <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200, nsa_w = 0.5, tol = 1e-5,proximal_type='basic', verbose = FALSE )
915-
res_nns <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200, nsa_w = 0.5, tol = 1e-5,proximal_type='nsa_flow', verbose = FALSE )
936+
res_nns <- nsa_flow_pca( X, nembed, lambda = 0.05, alpha = 0.1, max_iter = 200, nsa_w = 0.5, tol = 1e-5, proximal_type='nsa_flow', nsa_flow_fn = nsa_default, verbose = FALSE )
916937
917938
918939
@@ -1054,10 +1075,10 @@ proj_std <- pca_std$x
10541075
10551076
res_basic <- nsa_flow_pca(golub_scaled_ss, myk,lambda = 0.1, alpha = 0.001,
10561077
max_iter = mxit, proximal_type = "basic", tol = 1e-5,
1057-
nsa_w = 0.5, verbose = F)
1078+
nsa_w = omega_default, verbose = F)
10581079
res_nns <- nsa_flow_pca(golub_scaled_ss, myk, lambda = 0.1, alpha = 0.001,
10591080
max_iter = mxit, proximal_type = "nsa_flow", tol = 1e-5,
1060-
nsa_w = 0.5, verbose = F)
1081+
nsa_w = omega_default, nsa_flow_fn = nsa_default, verbose = FALSE)
10611082
10621083
## --- Core Metrics ------------------------------------------------------------
10631084
metrics_pca_g <- compute_core_metrics(pca_std$rotation, golub_scaled_ss)
@@ -1088,6 +1109,7 @@ golub_metrics <- tibble(
10881109
CV_Accuracy = c(acc_std$Accuracy, acc_basic$Accuracy, acc_nns$Accuracy),
10891110
CV_Accuracy_SD = c(acc_std$AccuracySD, acc_basic$AccuracySD, acc_nns$AccuracySD)
10901111
)
1112+
#####
10911113
```
10921114

10931115

@@ -1361,7 +1383,7 @@ results_df <- data.frame(
13611383
for (ww in wws) {
13621384
# cat(paste0("Running NSA-Flow with optimizer = ", oo, " and w = ", ww, "\n"))
13631385
1364-
M_nsa <- nsa_default( Y0_pca, w = ww, o=optype, verbose = FALSE )
1386+
M_nsa <- nsa_default( Y0_pca, w = ww, verbose = FALSE )
13651387
13661388
if ( any( apply( M_nsa$Y, FUN=var, MARGIN=2) == 0 ) ) {
13671389
cat("Warning: zero-variance component detected, skipping this run.\n")

0 commit comments

Comments
 (0)