Skip to content

Commit 04daf8e

Browse files
committed
ENH: nsa flow torch with optimizers
1 parent b5c1476 commit 04daf8e

File tree

1 file changed

+107
-27
lines changed

1 file changed

+107
-27
lines changed

R/nsa_flow_torch.R

Lines changed: 107 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,128 @@
1-
#' @title NSA-Flow Optimization via PyTorch
#'
#' @description
#' Performs optimization to balance fidelity to a target matrix and orthogonality
#' of the solution matrix using a weighted objective function. The function
#' supports multiple retraction methods and includes robust convergence checks.
#' These constraints provide global control over column-wise sparseness by
#' projecting the matrix onto the approximate Stiefel manifold.
#'
#' @param Y0 Numeric matrix of size \code{p x k}, the initial guess for the solution.
#' @param X0 Numeric matrix of size \code{p x k}, the target matrix for fidelity.
#'   If \code{NULL}, it is initialized as \code{pmax(Y0, 0)}.
#' @param w Numeric scalar in \code{[0,1]}, weighting the trade-off between fidelity
#'   (1 - w) and orthogonality (w). Default is 0.5.
#' @param retraction Character string specifying the retraction method used to
#'   enforce orthogonality constraints. One of \code{"soft_polar"},
#'   \code{"polar"}, or \code{"none"}.
#' @param max_iter Integer, maximum number of iterations. Default is 500.
#' @param tol Numeric, tolerance for convergence based on relative gradient norm
#'   and energy stability. Default is 1e-5.
#' @param verbose Logical, if \code{TRUE}, prints iteration details. Default is \code{FALSE}.
#' @param seed Integer, random seed for reproducibility. If \code{NULL}, no seed
#'   is set. Default is 42.
#' @param apply_nonneg Logical, if \code{TRUE}, enforces non-negativity on the
#'   solution after retraction. Default is \code{TRUE}.
#' @param optimizer Character string, optimization algorithm to use. The "fast"
#'   option selects the best option based on whether \code{simplified} is
#'   \code{TRUE} or \code{FALSE}; otherwise, pass the name of an optimizer
#'   supported by \code{create_optimizer()} as seen in
#'   \code{list_simlr_optimizers()}. Default is "fast".
#' @param initial_learning_rate Numeric, initial learning rate for the optimizer,
#'   or a string naming an estimation method, one of
#'   \code{c("brent", "grid", "armijo", "golden", "adaptive")}.
#'   The default \code{"default"} lets the backend choose (1e-3 for
#'   non-negative, 1 for unconstrained); \code{estimate_learning_rate_nsa()}
#'   can be used to find a robust value.
#' @param record_every Integer, frequency of recording iteration metrics.
#'   Default is 1 (record every iteration).
#' @param window_size Integer, size of the window for the energy-stability
#'   convergence check. Default is 5.
#' @param c1_armijo Numeric, Armijo condition constant for line search.
#'   NOTE(review): currently accepted but not forwarded to the Python backend
#'   -- confirm the backend signature before wiring it through.
#' @param simplified Logical, if \code{TRUE}, uses the simplified objective
#'   \deqn{\min_U (1 - w) \frac{1}{2} ||U - Z||_F^2 + w \frac{1}{2} ||U^\top U - I_k||_F^2}.
#'   If \code{FALSE}, uses the invariant defect objective. Default is \code{FALSE}.
#' @param plot Logical. NOTE(review): currently accepted but unused by this
#'   wrapper; plotting of fidelity/orthogonality traces is not implemented here.
#'   Default is \code{FALSE}.
#'
#' @return A list containing:
#' \itemize{
#'   \item \code{Y}: Numeric matrix, the best solution found (lowest total energy).
#'   \item \code{energy}: Numeric, the lowest total energy achieved
#'     (the backend's \code{best_total_energy}).
#'   \item \code{iter}: Integer, number of iterations performed
#'     (the backend's \code{final_iter}).
#' }
#'
#' @details
#' The function minimizes a weighted objective combining fidelity to \code{X0}
#' and orthogonality of \code{Y}, defined as:
#' \deqn{E(Y) = (1 - w) * ||Y - X0||_F^2 / (2 * p * k) + w * defect(Y)}
#' where \code{defect(Y)} measures orthogonality deviation.
#'
#' The optimization uses a Riemannian gradient descent approach with optional
#' retraction to enforce orthogonality constraints. Convergence is checked via
#' relative gradient norm and energy stability over a window of iterations.
#'
#' Requires the Python packages \code{torch} and \code{nsa_flow}, accessed via
#' \pkg{reticulate}.
#'
#' @examples
#' \dontrun{
#' set.seed(123)
#' Y0 <- matrix(runif(20), 5, 4)
#' X0 <- matrix(runif(20), 5, 4)
#' result <- nsa_flow_torch(Y0, X0, w = 0.0, max_iter = 10, verbose = TRUE)
#' print(result$energy)
#' }
#'
#' @import ggplot2
#' @import reshape2
#' @export
nsa_flow_torch <- function(
  Y0, X0 = NULL, w = 0.5,
  retraction = c("soft_polar", "polar", "none"),
  max_iter = 500, tol = 1e-5, verbose = FALSE, seed = 42,
  apply_nonneg = TRUE, optimizer = "fast",
  initial_learning_rate = "default",
  record_every = 1, window_size = 5, c1_armijo = 1e-6,
  simplified = FALSE,
  plot = FALSE
) {
  # --- input validation -------------------------------------------------
  if (!is.matrix(Y0) || !is.numeric(Y0)) {
    stop("Y0 must be a numeric matrix.", call. = FALSE)
  }
  if (!is.numeric(w) || length(w) != 1L || is.na(w) || w < 0 || w > 1) {
    stop("w must be a single number in [0, 1].", call. = FALSE)
  }
  retraction_type <- match.arg(retraction)

  # Documented fallback: with no target, use the non-negative part of the
  # initial guess as the fidelity target.
  if (is.null(X0)) {
    X0 <- pmax(Y0, 0)
  }
  if (!is.matrix(X0) || !is.numeric(X0) || !all(dim(X0) == dim(Y0))) {
    stop("X0 must be a numeric matrix with the same dimensions as Y0.",
         call. = FALSE)
  }

  # --- Python backend ---------------------------------------------------
  torch <- reticulate::import("torch", convert = FALSE)
  pynsa <- reticulate::import("nsa_flow", convert = FALSE)
  if (is.null(pynsa)) {
    stop("Could not find Python package `nsa_flow` -- please install it first.")
  }

  Y_torch <- torch$tensor(Y0, dtype = torch$float64)
  Xc_torch <- torch$tensor(X0, dtype = torch$float64)

  res <- pynsa$nsa_flow(
    Y_torch,
    Xc_torch,
    w = w,
    # Pass the validated single choice, not the raw (possibly length-3)
    # default vector supplied to `retraction`.
    retraction = retraction_type,
    max_iter = as.integer(max_iter),
    tol = tol,
    verbose = verbose,
    # as.integer(NULL) would yield integer(0); reticulate maps NULL -> None.
    seed = if (is.null(seed)) NULL else as.integer(seed),
    apply_nonneg = apply_nonneg,
    optimizer = optimizer,
    initial_learning_rate = initial_learning_rate,
    record_every = as.integer(record_every),
    window_size = as.integer(window_size),
    simplified = simplified
  )

  # detach() before numpy() so a gradient-tracking tensor converts cleanly.
  list(
    Y = as.matrix(res$Y$detach()$numpy()),
    energy = reticulate::py_to_r(res$best_total_energy),
    iter = reticulate::py_to_r(res$final_iter)
  )
}

0 commit comments

Comments
 (0)