
Commit 87662a3

WIP: update to nsa flow 0.7.5
1 parent 0a877ff commit 87662a3

6 files changed: +322 -31 lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -332,6 +332,7 @@ export(n4BiasFieldCorrection)
 export(neg_violation)
 export(networkEiganat)
 export(nsa_flow)
+export(nsa_flow_autograd)
 export(nsa_flow_fa)
 export(nsa_flow_fa_diagram)
 export(nsa_flow_flowchart)

R/nsa_flow_torch.R

Lines changed: 225 additions & 0 deletions
@@ -533,3 +533,228 @@ nsa_flow_torch_ag <- function(
     plot = if (plot) energy_plot else NULL
   )
 }
+
+
+
+#' @title NSA-Flow Optimization via PyTorch AutoGrad (R wrapper)
+#'
+#' @description
+#' Mirror of Python `nsa_flow_autograd()` but callable from R via reticulate.
+#' Uses the Python implementation (nsa_flow.nsa_flow_autograd) and returns
+#' R-friendly results (matrix, data.frame, ggplot).
+#'
+#' @param Y0 numeric matrix p x k initial guess
+#' @param X0 numeric matrix p x k target (or NULL to initialize from Y0)
+#' @param w numeric in [0,1] weighting fidelity vs orthogonality
+#' @param retraction character retraction method
+#' @param max_iter integer max iterations
+#' @param tol numeric convergence tolerance
+#' @param verbose logical
+#' @param seed integer random seed
+#' @param apply_nonneg logical or 'softplus'/'none' etc.
+#' @param optimizer character optimizer name (e.g. 'Adam','lars','sgdp')
+#' @param initial_learning_rate NULL (auto), numeric, or character strategy string
+#' @param lr_strategy character passed to Python if initial_learning_rate is NULL/'auto'
+#' @param fidelity_type character ('basic','scale_invariant','symmetric','normalized')
+#' @param orth_type character ('basic','scale_invariant')
+#' @param record_every integer frequency of recording traces
+#' @param window_size integer window for energy stability
+#' @param plot logical produce ggplot (default FALSE)
+#' @param precision 'float32' or 'float64'
+#' @return list: Y (matrix), traces (data.frame), final_iter, best_total_energy, best_Y_iteration, plot (ggplot or NULL), settings
+#' @export
+nsa_flow_autograd <- function(
+  Y0,
+  X0 = NULL,
+  w = 0.5,
+  retraction = c("soft_polar", "polar", "none"),
+  max_iter = 500L,
+  tol = 1e-6,
+  verbose = FALSE,
+  seed = 42L,
+  apply_nonneg = TRUE,
+  optimizer = "Adam",
+  initial_learning_rate = NULL,
+  lr_strategy = "auto",
+  fidelity_type = "scale_invariant",
+  orth_type = "scale_invariant",
+  record_every = 1L,
+  window_size = 5L,
+  plot = FALSE,
+  precision = "float64"
+) {
+  # basic checks
+  if (!is.matrix(Y0)) stop("Y0 must be a numeric matrix.")
+  if (!is.null(X0) && !is.matrix(X0)) stop("X0 must be a numeric matrix or NULL.")
+  retraction <- match.arg(retraction)
+
+  # import python modules
+  reticulate::py_config() # (optional debug) - remove if noisy
+  torch <- reticulate::import("torch", convert = FALSE)
+  pynsa <- reticulate::import("nsa_flow", convert = FALSE)
+
+  if (is.null(pynsa)) stop("Could not import python package 'nsa_flow'")
+
+  # convert Y0 and X0 to torch tensors (no convert)
+  dtype_str <- ifelse(precision == "float32", "float32", "float64")
+  torch_dtype <- if (dtype_str == "float32") torch$float32 else torch$float64
+
+  Y_torch <- torch$tensor(Y0, dtype = torch_dtype)
+  X_torch <- if (is.null(X0)) reticulate::r_to_py(NULL) else torch$tensor(X0, dtype = torch_dtype)
+
+  # prepare initial_learning_rate argument for Python call
+  if (is.null(initial_learning_rate)) {
+    py_init_lr <- reticulate::r_to_py(NULL)
+    py_lr_strategy <- lr_strategy
+  } else if (is.character(initial_learning_rate)) {
+    # pass through strategy string (e.g. "armijo", "auto")
+    py_init_lr <- reticulate::r_to_py(initial_learning_rate)
+    py_lr_strategy <- initial_learning_rate
+  } else if (is.numeric(initial_learning_rate)) {
+    py_init_lr <- reticulate::r_to_py(as.double(initial_learning_rate))
+    py_lr_strategy <- NA_character_
+  } else {
+    stop("initial_learning_rate must be NULL, a numeric, or a strategy string.")
+  }
+
+  # normalize apply_nonneg options: allow TRUE, FALSE, "softplus", "none"
+  apply_nonneg_py <- switch(
+    as.character(apply_nonneg),
+    "TRUE" = reticulate::r_to_py(TRUE),
+    "FALSE" = reticulate::r_to_py(FALSE),
+    "softplus" = reticulate::r_to_py("softplus"),
+    "none" = reticulate::r_to_py("none"),
+    reticulate::r_to_py(apply_nonneg)
+  )
+
+  # Call python function; be defensive about argument names (match Python signature)
+  py_args <- list(
+    Y0 = Y_torch,
+    X0 = X_torch,
+    w = as.double(w),
+    retraction = retraction,
+    max_iter = as.integer(max_iter),
+    tol = as.double(tol),
+    verbose = reticulate::r_to_py(verbose),
+    seed = as.integer(seed),
+    apply_nonneg = apply_nonneg_py,
+    optimizer = optimizer,
+    initial_learning_rate = py_init_lr,
+    lr_strategy = lr_strategy,
+    fidelity_type = fidelity_type,
+    orth_type = orth_type,
+    record_every = as.integer(record_every),
+    window_size = as.integer(window_size),
+    precision = precision
+  )
+
+  # Try to call python function and handle errors clearly
+  res_py <- tryCatch(
+    {
+      do.call(pynsa$nsa_flow_autograd, py_args)
+    },
+    error = function(e) {
+      stop("Error calling Python nsa_flow_autograd():\n", e$message)
+    }
+  )
+
+  # Convert outputs
+  # Expect res_py to be a dict-like object with keys "Y", "traces", "final_iter", "best_total_energy", "best_Y_iteration", "target", "settings"
+  # Convert Y (torch tensor) -> numeric matrix
+  Y_out <- NULL
+  if (!is.null(res_py$Y)) {
+    # res_py$Y might be a torch tensor; convert safely
+    # Use detach() if available
+    try({
+      # If it's a tensor object, call detach().numpy(); else, try py_to_r
+      if (!is.null(res_py$Y$detach)) {
+        Y_out <- as.matrix(res_py$Y$detach()$cpu()$numpy())
+      } else {
+        Y_out <- reticulate::py_to_r(res_py$Y)
+      }
+    }, silent = TRUE)
+
+    # fallback
+    if (is.null(Y_out)) {
+      Y_out <- tryCatch(reticulate::py_to_r(res_py$Y), error = function(e) NULL)
+    }
+  }
+
+  # Convert traces to data.frame if possible
+  traces_df <- NULL
+  if (!is.null(res_py$traces)) {
+    traces_df <- tryCatch({
+      reticulate::py_to_r(res_py$traces)
+    }, error = function(e) {
+      # If traces is a list of dicts, convert manually
+      tlist <- reticulate::py_to_r(res_py$traces)
+      if (is.list(tlist) && length(tlist) > 0 && is.list(tlist[[1]])) {
+        do.call(rbind, lapply(tlist, function(x) as.data.frame(x, stringsAsFactors = FALSE)))
+      } else {
+        as.data.frame(tlist)
+      }
+    })
+    # ensure rownames removed
+    rownames(traces_df) <- NULL
+  }
+
+  # final iter & energy
+  final_iter <- tryCatch(reticulate::py_to_r(res_py$final_iter), error = function(e) NA)
+  best_total_energy <- tryCatch(reticulate::py_to_r(res_py$best_total_energy), error = function(e) NA)
+  best_Y_iteration <- tryCatch(reticulate::py_to_r(res_py$best_Y_iteration), error = function(e) NA)
+  settings <- tryCatch(reticulate::py_to_r(res_py$settings), error = function(e) NULL)
+
+  # Build ggplot trace if requested and data available
+  energy_plot <- NULL
+  if (plot && !is.null(traces_df) && nrow(traces_df) > 0) {
+    if (!("fidelity" %in% colnames(traces_df)) || !("orthogonality" %in% colnames(traces_df))) {
+      # try to coerce likely-named columns
+      possible_fid <- grep("fid", names(traces_df), value = TRUE, ignore.case = TRUE)
+      possible_orth <- grep("orth", names(traces_df), value = TRUE, ignore.case = TRUE)
+      if (length(possible_fid) >= 1) names(traces_df)[which(names(traces_df) == possible_fid[1])] <- "fidelity"
+      if (length(possible_orth) >= 1) names(traces_df)[which(names(traces_df) == possible_orth[1])] <- "orthogonality"
+    }
+
+    if ("fidelity" %in% colnames(traces_df) && "orthogonality" %in% colnames(traces_df)) {
+      max_fid <- max(traces_df$fidelity, na.rm = TRUE)
+      max_orth <- max(traces_df$orthogonality, na.rm = TRUE)
+      ratio <- if (max_orth > 0) max_fid / max_orth else 1
+
+      energy_plot <- ggplot2::ggplot(traces_df, ggplot2::aes(x = iter)) +
+        ggplot2::geom_line(ggplot2::aes(y = fidelity, color = "Fidelity"), size = 1.1) +
+        ggplot2::geom_point(ggplot2::aes(y = fidelity, color = "Fidelity"), size = 1.2, alpha = 0.7) +
+        ggplot2::geom_line(ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"), size = 1.1) +
+        ggplot2::geom_point(ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"), size = 1.2, alpha = 0.7) +
+        ggplot2::scale_y_continuous(name = "Fidelity Energy",
+                                    sec.axis = ggplot2::sec_axis(~ . / ratio, name = "Orthogonality Defect")) +
+        ggplot2::scale_color_manual(values = c("Fidelity" = "#1f78b4", "Orthogonality" = "#33a02c")) +
+        ggplot2::labs(title = paste("NSA-Flow Optimization Trace:", retraction),
+                      subtitle = paste0("fidelity_type=", fidelity_type, ", orth_type=", orth_type),
+                      x = "Iteration", color = "Term") +
+        ggplot2::theme_minimal(base_size = 13) +
+        ggplot2::theme(plot.title = ggplot2::element_text(face = "bold", hjust = 0.5),
+                       legend.position = "top")
+    }
+  }
+
+  # Rescale target/back to original magnitude if Python returned scaled Y or target info:
+  # If Python returned 'target' or 'settings' containing scale_ref, you could rescale;
+  # here we assume outputs are in original scale or the Python function already rescaled.
+  # Provide Y_out as numeric matrix; maintain dimnames from input
+  if (!is.null(Y_out)) {
+    dimnames(Y_out) <- list(rownames(Y0), colnames(Y0))
+  }
+
+  out <- list(
+    Y = Y_out,
+    traces = traces_df,
+    final_iter = final_iter,
+    best_total_energy = best_total_energy,
+    best_Y_iteration = best_Y_iteration,
+    plot = if (plot) energy_plot else NULL,
+    settings = settings
+  )
+
+  class(out) <- c("nsa_flow_result", class(out))
+  return(out)
+}
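
For orientation, a minimal usage sketch of the new wrapper (not part of this commit), assuming the package that exports nsa_flow_autograd() is attached and reticulate can import the Python torch and nsa_flow modules:

# Hypothetical example (not from this commit): small random nonnegative problem.
set.seed(1)
Y0 <- matrix(abs(rnorm(50 * 3)), nrow = 50, ncol = 3)  # p x k initial guess

fit <- nsa_flow_autograd(
  Y0,
  w = 0.5,                # balance fidelity vs orthogonality
  optimizer = "Adam",
  max_iter = 100L,
  plot = TRUE
)

dim(fit$Y)                # p x k matrix, dimnames carried over from Y0
head(fit$traces)          # per-iteration energy trace as a data.frame
fit$best_total_energy     # best total energy reported by the Python optimizer
if (!is.null(fit$plot)) print(fit$plot)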

docs/compare_optimizers_for_NSA-Flow.Rmd

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ if (exists("matoption") && matoption == "wide") mat1 <- t(mat1)
 ws <- seq(0.05, 0.95, by = 0.05)
 optimizers <- list_simlr_optimizers( torch=TRUE )
 goptimizers = optimizers[ !(optimizers %in% c("adamp","lbfgs","padam","sgdp","adabound","adamax"))]
-max_iter <- 500
+max_iter <- 50
 tol <- 1e-6
 def_ret <- "soft_polar"
 if (!exists("raw_results")) {
@@ -37,7 +37,7 @@ if (!exists("raw_results")) {
     # print(paste("Running optimizer:", o, "with w =", round(w, 2)))
     start_time <- Sys.time()
     res = nsa_flow_torch_ag(mat1, w = w, verbose = F, retraction = def_ret,
-      optimizer = o, initial_learning_rate=NA )
+      optimizer = o, initial_learning_rate=NA, max_iter=max_iter, tol=tol)
     end_time <- Sys.time()
     runtime <- as.numeric(difftime(end_time, start_time, units = "secs"))
     if (!is.null(res)) {
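
The surrounding benchmark loop is not part of this diff; the sketch below is an assumption about its shape, inferred only from the objects defined in the first hunk (ws, goptimizers, max_iter, tol, def_ret) and the timing lines shown above:

# Hypothetical reconstruction of the enclosing loop (assumption, not in this diff):
# time nsa_flow_torch_ag for each optimizer and each value of w.
raw_results <- list()
for (o in goptimizers) {
  for (w in ws) {
    start_time <- Sys.time()
    res <- nsa_flow_torch_ag(mat1, w = w, verbose = FALSE, retraction = def_ret,
                             optimizer = o, initial_learning_rate = NA,
                             max_iter = max_iter, tol = tol)
    end_time <- Sys.time()
    runtime <- as.numeric(difftime(end_time, start_time, units = "secs"))
    if (!is.null(res)) {
      # record the timing (plus whatever energy fields res exposes)
      raw_results[[length(raw_results) + 1]] <-
        data.frame(optimizer = o, w = w, runtime = runtime)
    }
  }
}
raw_results <- do.call(rbind, raw_results)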

man/list_simlr_optimizers.Rd

Lines changed: 4 additions & 1 deletion
Generated file; diff not rendered by default.

man/nsa_flow_autograd.Rd

Lines changed: 72 additions & 0 deletions
Generated file; diff not rendered by default.
