1 | | -#' Run NSA-Flow via PyTorch backend |
| 1 | +#' @title NSA-Flow Optimization via PyTorch |
2 | 2 | #' |
3 | | -#' @param Y A numeric matrix |
4 | | -#' @param Xc A numeric matrix |
5 | | -#' @param w_pca,lambda,lr,max_iter,retraction_type,armijo_beta,armijo_c,tol,verbose Control parameters |
| 3 | +#' @description |
| 4 | +#' Performs optimization to balance fidelity to a target matrix and orthogonality |
| 5 | +#' of the solution matrix using a weighted objective function. The function supports multiple retraction methods and includes robust convergence checks. The orthogonality constraint provides global control over column-wise sparseness by projecting the matrix onto the approximate Stiefel manifold.
6 | 6 | #' |
7 | | -#' @return A list with elements `Y`, `energy`, and `iter` |
| 7 | +#' @param Y0 Numeric matrix of size \code{p x k}, the initial guess for the solution. |
| 8 | +#' @param X0 Numeric matrix of size \code{p x k}, the target matrix for fidelity. |
| 9 | +#' If \code{NULL}, it is initialized as \code{pmax(Y0, 0)} with a small perturbation added.
| 10 | +#' @param w Numeric scalar in \code{[0,1]}, weighting the trade-off between fidelity |
| 11 | +#' (1 - w) and orthogonality (w). Default is 0.5. |
| 12 | +#' @param retraction Character string specifying the retraction method to enforce |
| 13 | +#' orthogonality constraints. One of \code{"soft_polar"} (default), \code{"polar"}, or \code{"none"}.
| 14 | +#' @param max_iter Integer, maximum number of iterations. Default is 500.
| 15 | +#' @param tol Numeric, tolerance for convergence based on relative gradient norm |
| 16 | +#' and energy stability. Default is 1e-5.
| 17 | +#' @param verbose Logical, if \code{TRUE}, prints iteration details. Default is \code{FALSE}. |
| 18 | +#' @param seed Integer, random seed for reproducibility. If \code{NULL}, no seed is set. |
| 19 | +#' Default is 42. |
| 20 | +#' @param apply_nonneg Logical, if \code{TRUE}, enforces non-negativity on the solution |
| 21 | +#' after retraction. Default is \code{TRUE}. |
| 22 | +#' @param optimizer Character string, optimization algorithm to use. The \code{"fast"} option
| 23 | +#' selects a suitable algorithm based on whether \code{simplified} is \code{TRUE} or \code{FALSE};
| 24 | +#' otherwise, pass one of the optimizer names supported by \code{create_optimizer()},
| 25 | +#' as listed by \code{list_simlr_optimizers()}. Default is \code{"fast"}.
| 26 | +#' @param initial_learning_rate Numeric, initial learning rate for the optimizer.
| 27 | +#' The default resolves to 1e-3 when non-negativity is enforced and to 1 in the unconstrained case; \code{estimate_learning_rate_nsa()} can be used to find a robust value.
| 28 | +#' Alternatively, pass one of \code{c("brent", "grid", "armijo", "golden", "adaptive")} as a string to estimate the learning rate with that method.
| 29 | +#' @param record_every Integer, frequency of recording iteration metrics. |
| 30 | +#' Default is 1 (record every iteration). |
| 31 | +#' @param window_size Integer, size of the window for energy stability convergence check. |
| 32 | +#' Default is 5. |
| 33 | +#' @param c1_armijo Numeric, Armijo condition constant for line search. Default is \code{1e-6}.
| 34 | +#' @param simplified Logical, if \code{TRUE}, uses the simplified objective |
| 35 | +#' \deqn{\min_U (1 - w) \frac{1}{2} ||U - Z||_F^2 + w \frac{1}{2} ||U^\top U - I_k||_F^2}. |
| 36 | +#' If \code{FALSE}, uses the invariant defect objective. Default is \code{FALSE}. |
| 37 | +#' @param plot Logical, if \code{TRUE}, generates a ggplot of fidelity and orthogonality |
| 38 | +#' traces with dual axes. Default is \code{FALSE}. (This argument is not currently forwarded to the PyTorch backend.)
| 39 | +#' |
| 40 | +#' @return A list containing:
| 41 | +#' \itemize{
| 42 | +#'   \item \code{Y}: Numeric matrix, the best solution found
| 43 | +#'     (lowest total energy), converted back to an R matrix.
| 44 | +#'   \item \code{energy}: Numeric, the lowest total energy achieved.
| 45 | +#'   \item \code{iter}: Integer, the number of iterations performed.
| 46 | +#' }
| 51 | +#' |
| 52 | +#' @details |
| 53 | +#' The function minimizes a weighted objective combining fidelity to \code{X0} and |
| 54 | +#' orthogonality of \code{Y}, defined as: |
| 55 | +#' \deqn{E(Y) = (1 - w) \frac{||Y - X_0||_F^2}{2 p k} + w \, \mathrm{defect}(Y)}
| 56 | +#' where \code{defect(Y)} measures orthogonality deviation. |
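| 57 | +#'
| 58 | +#' As a rough illustration only (not the backend's exact implementation), the two
| 59 | +#' terms can be written in plain R as below; the Gram-matrix penalty shown matches
| 60 | +#' the \code{simplified = TRUE} objective, while the default invariant defect is a
| 61 | +#' related but different measure:
| 62 | +#' \preformatted{
| 63 | +#' p <- nrow(Y); k <- ncol(Y)
| 64 | +#' fidelity <- sum((Y - X0)^2) / (2 * p * k)
| 65 | +#' gram <- crossprod(Y)                  # the k-by-k Gram matrix of Y
| 66 | +#' orthogonality <- sum((gram - diag(k))^2) / 2
| 67 | +#' energy <- (1 - w) * fidelity + w * orthogonality
| 68 | +#' }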
| 57 | +#' |
| 58 | +#' The optimization uses a Riemannian gradient descent approach with optional |
| 59 | +#' retraction to enforce orthogonality constraints. Convergence is checked via |
| 60 | +#' relative gradient norm and energy stability over a window of iterations. |
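| 61 | +#'
| 62 | +#' For intuition only (the backend's retractions are implemented in Python and may
| 63 | +#' differ in detail), a plain polar retraction can be sketched in R via the SVD;
| 64 | +#' the \code{"soft_polar"} option presumably relaxes this exact projection:
| 65 | +#' \preformatted{
| 66 | +#' polar_retract <- function(Y) {
| 67 | +#'   s <- svd(Y)
| 68 | +#'   tcrossprod(s$u, s$v)  # the closest matrix with orthonormal columns
| 69 | +#' }
| 70 | +#' }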
| 61 | +#' |
| 62 | +#' @examples |
| 63 | +#' set.seed(123) |
| 64 | +#' Y0 <- matrix(runif(20), 5, 4) |
| 65 | +#' X0 <- matrix(runif(20), 5, 4) |
| 66 | +#' # This wrapper calls the Python 'nsa_flow' package through reticulate, so the
| 67 | +#' # Python modules 'torch' and 'nsa_flow' must be installed for it to run.
| 68 | +#' # The call is therefore shown but not executed here:
| 69 | +#' # result <- nsa_flow_torch(Y0, X0, w = 0.0, max_iter = 10, verbose = TRUE)
| 70 | +#' # str(result)  # a list with elements Y, energy, and iter
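| 71 | +#' # For larger values of w, the Gram matrix of the returned solution should be
| 72 | +#' # approximately the identity:
| 73 | +#' # round(crossprod(result$Y), 3)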
| 72 | +#' |
| 73 | +#' @import ggplot2 |
| 74 | +#' @import reshape2 |
8 | 75 | #' @export |
9 | 76 | nsa_flow_torch <- function( |
10 | | - Y, Xc, |
11 | | - w_pca = 1.0, lambda = 0.01, |
12 | | - lr = 1e-2, max_iter = 100L, |
13 | | - retraction_type = c("soft_polar", "polar", "none"), |
14 | | - armijo_beta = 0.5, armijo_c = 1e-4, |
15 | | - tol = 1e-6, verbose = FALSE |
| 77 | + Y0, X0 = NULL, w = 0.5, |
| 78 | + retraction = c("soft_polar", "polar", "none"),
| 79 | + max_iter = 500, tol = 1e-5, verbose = FALSE, seed = 42, |
| 80 | + apply_nonneg = TRUE, optimizer = "fast", |
| 81 | + initial_learning_rate = "default",
| 82 | + record_every = 1, window_size = 5, c1_armijo = 1e-6,
| 83 | + simplified = FALSE, |
| 84 | + plot = FALSE |
16 | 85 | ) { |
17 | | - retraction_type <- match.arg(retraction_type) |
| 86 | + if (!is.matrix(Y0) || !is.numeric(Y0)) {
| 87 | + stop("Y0 must be a numeric matrix.") |
| 88 | + } |
| 89 | + |
| 90 | + retraction_type <- match.arg(retraction) |
18 | 91 |
19 | 92 | torch <- reticulate::import("torch", convert = FALSE) |
20 | 93 | pynsa <- reticulate::import("nsa_flow", convert=FALSE ) |
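| 94 | + # with convert = FALSE, objects returned from Python below are converted back explicitly via py_to_r()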
21 | 94 |
22 | 95 |
23 | 96 | if (is.null(pynsa) ) { |
24 | | - stop("Could not find Python function `nsa_flow_py` in the 'nsa_flow_py' module.") |
| 97 | + stop("Could not find Python package `nsa_flow` -- please install it first.") |
25 | 98 | } |
26 | 99 |
27 | | - Y_torch <- torch$tensor(Y, dtype = torch$float64) |
28 | | - Xc_torch <- torch$tensor(Xc, dtype = torch$float64) |
| 100 | + Y_torch <- torch$tensor(Y0, dtype = torch$float64)
| 101 | + # if X0 is NULL, pass it through as None (the backend is assumed to build the default target)
| 102 | + Xc_torch <- if (is.null(X0)) NULL else torch$tensor(X0, dtype = torch$float64)
29 | 102 |
30 | | - res <- pynsa$nsa_flow_py( |
31 | | - Y_torch, Xc_torch, |
32 | | - w_pca = w_pca, |
33 | | - lambda_ = lambda, |
34 | | - lr = lr, |
| 103 | + res <- pynsa$nsa_flow( |
| 104 | + Y_torch, |
| 105 | + Xc_torch, |
| 106 | + w = w, |
| 107 | + retraction = retraction_type,
35 | 108 | max_iter = as.integer(max_iter), |
36 | | - retraction_type = retraction_type, |
37 | | - armijo_beta = armijo_beta, |
38 | | - armijo_c = armijo_c, |
39 | 109 | tol = tol, |
40 | | - verbose = verbose |
| 110 | + verbose = verbose, |
| 111 | + seed = if (is.null(seed)) NULL else as.integer(seed),
| 112 | + apply_nonneg = apply_nonneg, |
| 113 | + optimizer = optimizer, |
| 114 | + initial_learning_rate = initial_learning_rate, |
| 115 | + record_every = as.integer(record_every), |
| 116 | + window_size = as.integer(window_size), |
| 117 | + simplified = simplified
| 118 | + # armijo_beta and armijo_c from the previous interface are no longer passed
41 | 121 | ) |
42 | 122 |
43 | 123 | list( |
44 | | - Y = as.matrix(res$Y$numpy()), |
45 | | - energy = reticulate::py_to_r(res$energy), |
46 | | - iter = reticulate::py_to_r(res$iter) |
| 124 | + Y = as.matrix(reticulate::py_to_r(res$Y$detach()$numpy())),
| 125 | + energy = reticulate::py_to_r(res$best_total_energy), |
| 126 | + iter = reticulate::py_to_r(res$final_iter) |
47 | 127 | ) |
48 | 128 | } |