Skip to content

Commit 844f456

Browse files
committed
ENH: >..<
1 parent a55c32d commit 844f456

File tree

4 files changed

+27
-32
lines changed

4 files changed

+27
-32
lines changed

R/multiscaleSVDxpts.R

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11533,22 +11533,19 @@ nsa_flow_fa <- function(
1153311533

1153411534
# Anneal w if enabled (start from 0, linear to nsa_w)
1153511535
w_iter <- if (anneal_w) nsa_w * (iter / max_iter) else nsa_w
11536-
1153711536
# Apply NSA-Flow regularization with incremental update (Y0 from power loadings)
11538-
nsa_result <- tryCatch({
11539-
nsa_flow(
11537+
nsa_result = nsa_flow(
1154011538
Y0 = as.matrix(loadings_power),
1154111539
w = w_iter,
1154211540
max_iter = nsa_max_iter,
1154311541
retraction = "soft_polar",
11542+
seed = 42,
11543+
tol = 1e-6,
11544+
window_size=10,
1154411545
plot = TRUE,
11546+
verbose=FALSE,
1154511547
...
1154611548
)
11547-
}, error = function(e) {
11548-
warning("nsa_flow failed at iteration ", iter, ": ", conditionMessage(e))
11549-
NULL
11550-
})
11551-
1155211549
if (!is.null(nsa_result) && is.matrix(nsa_result$Y)) {
1155311550
loadings_post <- nsa_result$Y
1155411551
final_nsa_result <- nsa_result
@@ -11800,9 +11797,11 @@ nsa_flow_pca <- function(X, k,
1180011797
if (!is.finite(total_var) || total_var <= 0) stop("Input matrix X has zero or non-finite variance")
1180111798

1180211799
# --- SVD init: Y is p x k ---
11803-
sv <- svd(Xc, nu = 0, nv = k)
11804-
Y <- sv$v
11805-
if (ncol(Y) != k) stop("SVD initialization did not produce k columns")
11800+
set.seed(1234)
11801+
Y = matrix(rnorm(p * k), nrow = p, ncol = k)
11802+
# sv <- svd(Xc, nu = 0, nv = k)
11803+
# Y <- sv$v
11804+
# if (ncol(Y) != k) stop("SVD initialization did not produce k columns")
1180611805

1180711806
# bookkeeping
1180811807
energy_trace <- numeric(max_iter)
@@ -11919,7 +11918,7 @@ nsa_flow_pca <- function(X, k,
1191911918

1192011919
# check for improvement
1192111920
improved <- FALSE
11922-
if (energy < best_energy - min_delta) {
11921+
if (energy < best_energy ) {
1192311922
best_energy <- energy
1192411923
best_Y <- Y_new
1192511924
improved <- TRUE

R/nsa_flow_torch.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -574,9 +574,9 @@ nsa_flow_autograd <- function(
574574
verbose = FALSE,
575575
seed = 42L,
576576
apply_nonneg = TRUE,
577-
optimizer = "Adam",
577+
optimizer = "asgd",
578578
initial_learning_rate = NULL,
579-
lr_strategy = "auto",
579+
lr_strategy = "bayes",
580580
aggression = 0.5,
581581
fidelity_type = "scale_invariant",
582582
orth_type = "scale_invariant",

man/nsa_flow_autograd.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vignettes/nsa_flow_fa.Rmd

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ library(corrplot)
2424
library(ggplot2)
2525
library(patchwork)
2626
library(reshape2)
27+
library(ANTsR)
2728
set.seed(1234)
29+
2830
```
2931

3032
# Introduction
@@ -244,8 +246,6 @@ NSA-Flow-FA combines traditional statistical factor modeling with modern geometr
244246
For predictive comparisons, the dataset is split into 80% training and 20% testing sets. Models are fitted on the training set, and evaluations are performed on the test set.
245247

246248
```{r methods,echo=FALSE}
247-
248-
library(ANTsR)
249249
paf_standalone <- function(data = NULL, R = NULL, nfactors, rotate = "none", max_iter = 100, tol = 1e-5) {
250250
if (!is.null(data)) R <- cor(data)
251251
fit <- psych::fa(R, nfactors = nfactors, rotate = rotate, fm = "pa", max.iter = max_iter)
@@ -433,19 +433,14 @@ nsa_fa_flow <- function(
433433
w_iter <- if (anneal_w) nsa_w * (iter / max_iter) else nsa_w
434434
435435
# Apply NSA-Flow regularization with incremental update (Y0 from power loadings)
436-
nsa_result <- tryCatch({
437-
nsa_flow(
436+
nsa_result <- nsa_flow(
438437
Y0 = as.matrix(loadings_power),
439438
w = w_iter,
440439
max_iter = nsa_max_iter,
441440
retraction = "soft_polar",
442441
plot = TRUE,
443442
...
444443
)
445-
}, error = function(e) {
446-
warning("nsa_flow failed at iteration ", iter, ": ", conditionMessage(e))
447-
NULL
448-
})
449444
450445
if (!is.null(nsa_result) && is.matrix(nsa_result$Y)) {
451446
loadings_post <- nsa_result$Y
@@ -680,21 +675,21 @@ nrow(bfi_data)
680675

681676
```{r real-data-fit}
682677
# dook
683-
nsaval=0.50
684-
# c("brent", "grid", "armijo", "golden", "adaptive", "default")
685-
lrval='adaptive'
678+
nsaval=0.80
679+
lrval='bayes'
680+
myo='asgd'
686681
myo='lars'
687-
nsa_max_iter=500
682+
nsa_max_iter=50
688683
if ( ! exists("nn") ) nn=TRUE
689684
rot='varimax'
685+
agg=0.1
690686
# bfi_data_tx=transform_matrix(data.matrix(bfi_data),'frob')$Xs*sqrt(prod(dim(bfi_data)))
691687
bfi_data_tx=transform_matrix(data.matrix(bfi_data),'minmax')$Xs
692688
#############################################################################
693689
if ( ! exists("real_A") | ! exists("real_B") | TRUE ) {
694690
real_A <- paf_standalone(bfi_data_tx, nfactors = 5, rotate = rot)
695-
real_B <- nsa_fa_flow( data = bfi_data_tx, nfactors = 5, max_iter = 50,
696-
rotate = rot, nsa_w = nsaval, initial_learning_rate = lrval,
697-
optimizer = myo, nsa_max_iter = nsa_max_iter, apply_nonneg = nn )
691+
real_B <- nsa_fa_flow( data = bfi_data_tx, nfactors = 5, max_iter = 50, rotate = rot, nsa_w = nsaval, optimizer = myo, nsa_max_iter = nsa_max_iter, apply_nonneg = nn)
692+
#, aggression=agg, lr_strategy = lrval )
698693
}
699694
#############################################################################
700695
```
@@ -948,8 +943,9 @@ L_A <- real_A_train$loadings
948943
949944
# ---- Method B: NSA.FA ----
950945
real_B_train <-nsa_fa_flow( data = train_data, nfactors = 5, max_iter = 50,
951-
rotate = rot, nsa_w = nsaval, initial_learning_rate = lrval,
946+
rotate = rot, nsa_w = nsaval,
952947
optimizer = myo, nsa_max_iter = nsa_max_iter, apply_nonneg = nn )
948+
# lr_strategy = lrval, aggression=agg )
953949
954950
L_B <- real_B_train$loadings
955951

0 commit comments

Comments
 (0)