Skip to content

Commit 6de6b05

Browse files
committed
ENH: autograd wrapped nsa_flow
1 parent e33cfb5 commit 6de6b05

File tree

3 files changed

+388
-18
lines changed

3 files changed

+388
-18
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ export(nsa_flow_pca)
339339
export(nsa_flow_retract)
340340
export(nsa_flow_retract_auto)
341341
export(nsa_flow_torch)
342+
export(nsa_flow_torch_ag)
342343
export(oneHotToSegmentation)
343344
export(optimal_simlr_initializer)
344345
export(optimize_indicator_matrix)

R/nsa_flow_torch.R

Lines changed: 267 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -226,24 +226,273 @@ nsa_flow_torch <- function(
226226
)
227227

228228

229-
df = as.data.frame(reticulate::py_to_r(res$traces))
230-
# Suppose your data frame is named df
231-
cols <- c("iter", "time", "fidelity", "orthogonality", "total_energy")
232-
# Find all iteration suffixes
233-
suffixes <- unique(gsub(".*\\.", "", grep("\\.", names(df), value = TRUE)))
234-
suffixes <- c("", sort(unique(suffixes))) # include first iteration (no suffix)
235-
236-
# Build rows for each suffix
237-
rows <- lapply(suffixes, function(suf) {
238-
postfix <- if (suf == "") "" else paste0(".", suf)
239-
subset <- df[ , paste0(cols, postfix), drop = FALSE]
240-
names(subset) <- cols
241-
subset
242-
})
243-
244-
# Bind all iterations together
245-
trace_df <- do.call(rbind, rows)
246-
rownames(trace_df) <- NULL
229+
trace_df = as.data.frame(reticulate::py_to_r(res$traces))
230+
231+
if (plot && !is.null(trace_df) && nrow(trace_df) > 0) {
232+
max_fid <- max(trace_df$fidelity, na.rm = TRUE)
233+
max_orth <- max(trace_df$orthogonality, na.rm = TRUE)
234+
ratio <- if (max_orth > 0) max_fid / max_orth else 1
235+
energy_plot <- ggplot2::ggplot(trace_df, ggplot2::aes(x = iter)) +
236+
ggplot2::geom_line(ggplot2::aes(y = fidelity, color = "Fidelity"), size = 1.2) +
237+
ggplot2::geom_point(ggplot2::aes(y = fidelity, color = "Fidelity"), size = 1.5, alpha = 0.7) +
238+
ggplot2::geom_line(ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"), size = 1.2) +
239+
ggplot2::geom_point(ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"), size = 1.5, alpha = 0.7) +
240+
ggplot2::scale_y_continuous(name = "Fidelity Energy",
241+
sec.axis = ggplot2::sec_axis(~ . / ratio, name = "Orthogonality Defect")) +
242+
ggplot2::scale_color_manual(values = c("Fidelity" = "#1f78b4", "Orthogonality" = "#33a02c")) +
243+
ggplot2::labs(title = paste("NSA-Flow Optimization Trace: ", retraction),
244+
subtitle = "Fidelity and Orthogonality Terms (Dual Scales)",
245+
x = "Iteration", color = "Term") +
246+
ggplot2::theme_minimal(base_size = 14) +
247+
ggplot2::theme(plot.title = ggplot2::element_text(face = "bold", hjust = 0.5),
248+
plot.subtitle = ggplot2::element_text(hjust = 0.5),
249+
legend.position = "top",
250+
panel.grid.major = ggplot2::element_line(color = "gray80"),
251+
panel.grid.minor = ggplot2::element_line(color = "gray90"),
252+
axis.title.y.left = ggplot2::element_text(color = "#1f78b4"),
253+
axis.text.y.left = ggplot2::element_text(color = "#1f78b4"),
254+
axis.title.y.right = ggplot2::element_text(color = "#33a02c"),
255+
axis.text.y.right = ggplot2::element_text(color = "#33a02c"))
256+
}
257+
Y = as.matrix(res$Y$detach()$numpy())
258+
rownames(Y) <- rownames(Y0)
259+
colnames(Y) <- colnames(Y0)
260+
list(
261+
Y=Y,
262+
energy = reticulate::py_to_r(res$best_total_energy),
263+
traces = trace_df,
264+
iter = reticulate::py_to_r(res$final_iter),
265+
plot = if (plot) energy_plot else NULL
266+
)
267+
}
268+
269+
270+
271+
272+
#' @title NSA-Flow Optimization via PyTorch AutoGrad
273+
#'
274+
#' @description
275+
#' Performs optimization to balance fidelity to a target matrix and orthogonality
276+
#' of the solution matrix using a weighted objective function. The function supports multiple retraction methods and includes robust convergence checks. These constraints provide global control over column-wise sparseness by projecting the matrix onto the approximate Stiefel manifold.
277+
#'
278+
#' @param Y0 Numeric matrix of size \code{p x k}, the initial guess for the solution.
279+
#' @param X0 Numeric matrix of size \code{p x k}, the target matrix for fidelity.
280+
#' If \code{NULL}, initialized as \code{pmax(Y0, 0)} with a small perturbation added to \code{Y0}.
281+
#' @param w Numeric scalar in \code{[0,1]}, weighting the trade-off between fidelity
282+
#' (1 - w) and orthogonality (w). Default is 0.5.
283+
#' @param retraction Character string specifying the retraction method to enforce
284+
#' orthogonality constraints.
285+
#' @param max_iter Integer, maximum number of iterations. Default is 100.
286+
#' @param tol Numeric, tolerance for convergence based on relative gradient norm
287+
#' and energy stability. Default is 1e-6.
288+
#' @param verbose Logical, if \code{TRUE}, prints iteration details. Default is \code{FALSE}.
289+
#' @param seed Integer, random seed for reproducibility. If \code{NULL}, no seed is set.
290+
#' Default is 42.
291+
#' @param apply_nonneg Logical, if \code{TRUE}, enforces non-negativity on the solution
292+
#' after retraction. Default is \code{TRUE}.
293+
#' @param optimizer Character string, optimization algorithm to use. The "fast" option
294+
#' will select the best option based on whether simplified = TRUE or FALSE.
295+
#' otherwise, pass the names of optimizers supported by \code{create_optimizer()}
296+
#' as seen in \code{list_simlr_optimizers()}. Default is "fast".
297+
#' @param initial_learning_rate Numeric, initial learning rate for the optimizer.
298+
#' Default is 1e-3 for non-neg and 1 for unconstrained. Otherwise, you can use \code{estimate_learning_rate_nsa()} to find a robust value.
299+
#'. pass a string one of c("brent", "grid", "armijo", "golden", "adaptive") to engage this method.
300+
#' @param record_every Integer, frequency of recording iteration metrics.
301+
#' Default is 1 (record every iteration).
302+
#' @param window_size Integer, size of the window for energy stability convergence check.
303+
#' Default is 5.
304+
#' @param c1_armijo Numeric, Armijo condition constant for line search.
305+
#' @param simplified Logical, if \code{TRUE}, uses the simplified objective
306+
#' \deqn{\min_U (1 - w) \frac{1}{2} ||U - Z||_F^2 + w \frac{1}{2} ||U^\top U - I_k||_F^2}.
307+
#' If \code{FALSE}, uses the invariant defect objective. Default is \code{FALSE}.
308+
#' @param project_full_gradient Logical, if \code{TRUE}, projects the full gradient instead
309+
#' of just the orthogonal component. Default is \code{FALSE}.
310+
#' @param plot Logical, if \code{TRUE}, generates a ggplot of fidelity and orthogonality
311+
#' traces with dual axes. Default is \code{FALSE}.
312+
#' @param precision Character string, either 'float32' or 'float64' to set the precision
313+
#'
314+
#' @return A list containing:
315+
#' \itemize{
316+
#' \item \code{Y}: Numeric matrix, the best solution found (lowest total energy).
317+
#' \item \code{traces}: Data frame with columns \code{iter}, \code{time},
318+
#' \code{fidelity}, \code{orthogonality}, and \code{total_energy}
319+
#' for recorded iterations.
320+
#' \item \code{final_iter}: Integer, number of iterations performed.
321+
#' \item \code{plot}: ggplot object of the optimization trace
322+
#' (if \code{plot = TRUE}), otherwise \code{NULL}.
323+
#' \item \code{best_total_energy}: Numeric, the lowest total energy achieved.
324+
#' }
325+
#'
326+
#' @details
327+
#' The function minimizes a weighted objective combining fidelity to \code{X0} and
328+
#' orthogonality of \code{Y}, defined as:
329+
#' \deqn{E(Y) = (1 - w) * ||Y - X0||_F^2 / (2 * p * k) + w * defect(Y)}
330+
#' where \code{defect(Y)} measures orthogonality deviation.
331+
#'
332+
#' The optimization uses a Riemannian gradient descent approach with optional
333+
#' retraction to enforce orthogonality constraints. Convergence is checked via
334+
#' relative gradient norm and energy stability over a window of iterations.
335+
#'
336+
#' @examples
337+
#' set.seed(123)
338+
#' Y0 <- matrix(runif(20), 5, 4)
339+
#' X0 <- matrix(runif(20), 5, 4)
340+
#' # The original function relies on helper functions not shown here, such as:
341+
#' # create_optimizer, step, inv_sqrt_sym, symm, and invariant_orthogonality_defect.
342+
#' # The following example is conceptual:
343+
#' # result <- nsa_flow_torch(Y0, X0, w = 0.0, max_iter = 10, verbose = TRUE, plot = TRUE)
344+
#' # print(result$plot)
345+
#' # print(result$traces)
346+
#'
347+
#' @import ggplot2
348+
#' @import reshape2
349+
#' @export
350+
nsa_flow_torch_ag <- function(
351+
Y0, X0 = NULL, w = 0.5,
352+
retraction = c( "soft_polar", "polar", "none" ),
353+
max_iter = 500, tol = 1e-5, verbose = FALSE, seed = 42,
354+
apply_nonneg = TRUE, optimizer = "fast",
355+
initial_learning_rate = 'default',
356+
record_every = 1, window_size = 5, c1_armijo=1e-6,
357+
simplified = FALSE,
358+
project_full_gradient = FALSE,
359+
plot = FALSE,
360+
precision = 'float64'
361+
) {
362+
if (!is.matrix(Y0)) {
363+
stop("Y0 must be a numeric matrix.")
364+
}
365+
p <- nrow(Y0)
366+
k <- ncol(Y0)
367+
368+
if (is.null(X0)) {
369+
if ( apply_nonneg ) X0 <- pmax(Y0, 0) else X0 = Y0
370+
perturb_scale <- sqrt(sum(Y0^2)) / sqrt(length(Y0)) * 0.05
371+
Y0 <- Y0 + matrix(rnorm(nrow(Y0) * ncol(Y0), sd = perturb_scale), nrow(Y0), ncol(Y0))
372+
if (verbose) cat("Added perturbation to Y0\n")
373+
} else {
374+
if ( apply_nonneg ) X0 <- pmax(X0, 0)
375+
if (nrow(X0) != nrow(Y0) || ncol(X0) != ncol(Y0)) stop("X0 must have same dimensions as Y0")
376+
}
377+
378+
retraction_type <- match.arg(retraction)
379+
380+
# Fast ortho terms (used in gradients; optional c_orth scaling)
381+
compute_ortho_terms <- function(Y, c_orth = 1, simplified = FALSE ) {
382+
norm2 <- sum(Y^2)
383+
if (norm2 <= 1e-12 || c_orth <= 0) {
384+
return(list(grad_orth = matrix(0, nrow(Y), ncol(Y)), defect = 0, norm2 = norm2))
385+
}
386+
S <- crossprod(Y) # Once!
387+
diagS <- diag(S)
388+
off_f2 <- sum(S * S) - sum(diagS^2)
389+
defect <- off_f2 / norm2^2
390+
Y_S <- Y %*% S
391+
Y_diag_scale <- sweep(Y, 2, diagS, "*") # Columns of Y scaled by diagS
392+
term1 <- (Y_S - Y_diag_scale) / norm2^2
393+
term2 <- (defect / norm2) * Y
394+
if ( simplified ) {
395+
grad_orth <- - c_orth * 2 * Y %*% (S - diag(ncol(Y)))
396+
} else {
397+
grad_orth <- c_orth * (term1 - term2)
398+
}
399+
list(grad_orth = grad_orth, defect = defect, norm2 = norm2)
400+
}
401+
402+
Y <- Y0
403+
# --- Compute initial scales ---
404+
g0 <- 0.5 * sum((Y0 - X0)^2) / (p * k)
405+
if (g0 < 1e-8) g0 <- 1e-8
406+
d0 <- invariant_orthogonality_defect(Y0) # Fast!
407+
if (d0 < 1e-8) d0 <- 1e-8
408+
# --- Weighting terms ---
409+
fid_eta <- (1 - w) / (g0 * p * k)
410+
c_orth <- 4 * w / d0
411+
fid_eta_pt5 <- (1 - 0.5) / (g0 * p * k)
412+
c_orth_pt5 <- 4 * 0.5 / d0
413+
trace <- list()
414+
recent_energies <- numeric(0)
415+
t0 <- Sys.time()
416+
# --- Track best solution ---
417+
best_Y <- Y
418+
best_total_energy <- Inf
419+
# --- Compute initial gradient for relative norm tolerance ---
420+
grad_fid_init <- fid_eta * (Y - X0) * (-1.0)
421+
ortho_init <- compute_ortho_terms(Y, c_orth, simplified=simplified)
422+
grad_orth_init <- ortho_init$grad_orth
423+
if (c_orth > 0) {
424+
sym_term_orth_init <- symm(crossprod(Y, grad_orth_init)) # t(Y) %*% = crossprod(Y, .)
425+
rgrad_orth_init <- grad_orth_init - Y %*% sym_term_orth_init
426+
} else {
427+
rgrad_orth_init <- grad_orth_init
428+
}
429+
rgrad_init <- grad_fid_init + rgrad_orth_init
430+
init_grad_norm <- sqrt(sum(rgrad_init^2)) + 1e-8
431+
nsa_energy_pt5 <- function(Vp) {
432+
# --- Retraction ---
433+
Vp <- nsa_flow_retract_auto(Vp, 0.5, retraction)
434+
# --- Optional non-negativity ---
435+
Vp <- if (apply_nonneg) pmax(Vp, 0) else Vp
436+
e <- 0.5 * fid_eta_pt5 * sum((Vp - X0)^2)
437+
if (c_orth_pt5 > 0) {
438+
norm2_V <- sum(Vp^2)
439+
if (norm2_V > 0) {
440+
defect <- invariant_orthogonality_defect(Vp) # Fast!
441+
e <- e + 0.25 * c_orth_pt5 * defect
442+
}
443+
}
444+
e
445+
}
446+
447+
# --- Optimizer initialization ---
448+
if (is.null(initial_learning_rate) || is.na(initial_learning_rate)) {
449+
initial_learning_rate <- "brent"
450+
}
451+
if (is.character(initial_learning_rate)) {
452+
if (verbose)
453+
cat("Estimating robust initial learning rate using optim()...\n")
454+
lr_res <- estimate_learning_rate_nsa(Y0, X0, w = w,
455+
retraction = retraction, nsa_energy = nsa_energy_pt5,
456+
apply_nonneg = apply_nonneg, method = initial_learning_rate,
457+
verbose = verbose, plot = FALSE)$estimated_learning_rate
458+
459+
} else {
460+
lr_res <- initial_learning_rate
461+
}
462+
463+
print(lr_res)
464+
465+
466+
torch <- reticulate::import("torch", convert = FALSE)
467+
pynsa <- reticulate::import("nsa_flow", convert=FALSE )
468+
469+
470+
if (is.null(pynsa) ) {
471+
stop("Could not find Python package `nsa_flow` -- please install it first.")
472+
}
473+
474+
Y_torch <- torch$tensor(Y0, dtype = torch$float64)
475+
Xc_torch <- torch$tensor(X0, dtype = torch$float64)
476+
477+
res <- pynsa$nsa_flow_autograd(
478+
Y_torch,
479+
Xc_torch,
480+
w = w,
481+
retraction = retraction,
482+
max_iter = as.integer(max_iter),
483+
tol = tol,
484+
verbose = verbose,
485+
seed = as.integer(seed),
486+
apply_nonneg = apply_nonneg,
487+
optimizer = optimizer,
488+
initial_learning_rate = lr_res,
489+
record_every = as.integer(record_every),
490+
window_size = as.integer(window_size),
491+
precision=precision
492+
)
493+
494+
495+
trace_df = as.data.frame(reticulate::py_to_r(res$traces))
247496

248497
if (plot && !is.null(trace_df) && nrow(trace_df) > 0) {
249498
max_fid <- max(trace_df$fidelity, na.rm = TRUE)

0 commit comments

Comments
 (0)