#' @title NSA-Flow Optimization via PyTorch AutoGrad
#'
#' @description
#' Performs optimization to balance fidelity to a target matrix and orthogonality
#' of the solution matrix using a weighted objective function. The function supports multiple retraction methods and includes robust convergence checks. These constraints provide global control over column-wise sparseness by projecting the matrix onto the approximate Stiefel manifold.
#'
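#' @details
#' The optimization balances two energy terms, reported per iteration in
#' \code{traces}: a fidelity term \eqn{\frac{1}{2}\|U - Z\|_F^2} measuring
#' distance to the target, and an orthogonality term
#' \eqn{\frac{1}{2}\|U^\top U - I_k\|_F^2} measuring deviation from the Stiefel
#' manifold. In the simplified form (see \code{simplified}) these combine into
#' the weighted total energy
#' \deqn{E(U) = (1 - w)\,\frac{1}{2}\|U - Z\|_F^2 + w\,\frac{1}{2}\|U^\top U - I_k\|_F^2.}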
#' @param Y0 Numeric matrix of size \code{p x k}, the initial guess for the solution.
#' @param X0 Numeric matrix of size \code{p x k}, the target matrix for fidelity.
#' If \code{NULL}, initialized as \code{pmax(Y0, 0)} after a small perturbation
#' is added to \code{Y0}.
#' @param w Numeric scalar in \code{[0,1]}, weighting the trade-off between fidelity
#' (1 - w) and orthogonality (w). Default is 0.5.
#' @param retraction Character string specifying the retraction method to enforce
#' orthogonality constraints.
#' @param max_iter Integer, maximum number of iterations. Default is 100.
#' @param tol Numeric, tolerance for convergence based on relative gradient norm
#' and energy stability. Default is 1e-6.
#' @param verbose Logical, if \code{TRUE}, prints iteration details. Default is \code{FALSE}.
#' @param seed Integer, random seed for reproducibility. If \code{NULL}, no seed is set.
#' Default is 42.
#' @param apply_nonneg Logical, if \code{TRUE}, enforces non-negativity on the solution
#' after retraction. Default is \code{TRUE}.
#' @param optimizer Character string, the optimization algorithm to use. The
#' "fast" option selects a suitable algorithm automatically, based on whether
#' \code{simplified} is \code{TRUE} or \code{FALSE}. Otherwise, pass the name of
#' an optimizer supported by \code{create_optimizer()}, as listed by
#' \code{list_simlr_optimizers()}. Default is "fast".
#' @param initial_learning_rate Numeric, initial learning rate for the optimizer.
#' Default is 1e-3 for non-negative problems and 1 for unconstrained problems.
#' Alternatively, use \code{estimate_learning_rate_nsa()} to find a robust value,
#' or pass one of the strings \code{c("brent", "grid", "armijo", "golden",
#' "adaptive")} to engage the corresponding line-search method.
#' @param record_every Integer, frequency of recording iteration metrics.
#' Default is 1 (record every iteration).
#' @param window_size Integer, size of the window for energy stability convergence check.
#' Default is 5.
#' @param c1_armijo Numeric, Armijo condition constant for line search.
#' @param simplified Logical, if \code{TRUE}, uses the simplified objective
#' \deqn{\min_U (1 - w) \frac{1}{2} ||U - Z||_F^2 + w \frac{1}{2} ||U^\top U - I_k||_F^2}.
#' If \code{FALSE}, uses the invariant defect objective. Default is \code{FALSE}.
#' @param project_full_gradient Logical, if \code{TRUE}, projects the full gradient instead
#' of just the orthogonal component. Default is \code{FALSE}.
#' @param plot Logical, if \code{TRUE}, generates a ggplot of fidelity and orthogonality
#' traces with dual axes. Default is \code{FALSE}.
#' @param precision Character string, either 'float32' or 'float64', setting the
#' numeric precision of the torch computation.
#'
#' @return A list containing:
#' \itemize{
#' \item \code{Y}: Numeric matrix, the best solution found (lowest total energy).
#' \item \code{traces}: Data frame with columns \code{iter}, \code{time},
#' \code{fidelity}, \code{orthogonality}, and \code{total_energy}
#' for recorded iterations.
#' \item \code{final_iter}: Integer, number of iterations performed.
#' \item \code{plot}: ggplot object of the optimization trace (if \code{plot = TRUE}).
#' }
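#' @examples
#' \dontrun{
#' # A minimal usage sketch. The exported function name nsa_flow() is an
#' # assumption for illustration only; substitute the package's actual name.
#' set.seed(42)
#' p <- 50; k <- 3
#' Y0 <- matrix(abs(rnorm(p * k)), nrow = p, ncol = k)  # initial guess
#' res <- nsa_flow(Y0 = Y0, w = 0.5, max_iter = 100, tol = 1e-6,
#'                 verbose = TRUE, plot = TRUE)
#' dim(res$Y)        # best (lowest total energy) p x k solution
#' res$final_iter    # number of iterations performed
#' head(res$traces)  # fidelity, orthogonality, total_energy per iteration
#' print(res$plot)   # dual-axis ggplot of the optimization trace
#' # list_simlr_optimizers() lists names accepted by the optimizer argument.
#' }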