Skip to content

Commit 6e518c9

Browse files
committed
ENH: ....
1 parent 04daf8e commit 6e518c9

File tree

3 files changed

+169
-26
lines changed

3 files changed

+169
-26
lines changed

R/nsa_flow_torch.R

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
#' @param simplified Logical, if \code{TRUE}, uses the simplified objective
3535
#' \deqn{\min_U (1 - w) \frac{1}{2} ||U - Z||_F^2 + w \frac{1}{2} ||U^\top U - I_k||_F^2}.
3636
#' If \code{FALSE}, uses the invariant defect objective. Default is \code{FALSE}.
37+
#' @param project_full_gradient Logical, if \code{TRUE}, projects the full gradient instead
38+
#' of just the orthogonal component. Default is \code{FALSE}.
3739
#' @param plot Logical, if \code{TRUE}, generates a ggplot of fidelity and orthogonality
3840
#' traces with dual axes. Default is \code{FALSE}.
3941
#'
@@ -81,12 +83,23 @@ nsa_flow_torch <- function(
8183
initial_learning_rate = 'default',
8284
record_every = 1, window_size = 5, c1_armijo=1e-6,
8385
simplified = FALSE,
86+
project_full_gradient = FALSE,
8487
plot = FALSE
8588
) {
8689
if (!is.matrix(Y0)) {
8790
stop("Y0 must be a numeric matrix.")
8891
}
8992

93+
if (is.null(X0)) {
94+
if ( apply_nonneg ) X0 <- pmax(Y0, 0) else X0 = Y0
95+
perturb_scale <- sqrt(sum(Y0^2)) / sqrt(length(Y0)) * 0.05
96+
Y0 <- Y0 + matrix(rnorm(nrow(Y0) * ncol(Y0), sd = perturb_scale), nrow(Y0), ncol(Y0))
97+
if (verbose) cat("Added perturbation to Y0\n")
98+
} else {
99+
if ( apply_nonneg ) X0 <- pmax(X0, 0)
100+
if (nrow(X0) != nrow(Y0) || ncol(X0) != ncol(Y0)) stop("X0 must have same dimensions as Y0")
101+
}
102+
90103
retraction_type <- match.arg(retraction)
91104

92105
torch <- reticulate::import("torch", convert = FALSE)
@@ -115,14 +128,66 @@ nsa_flow_torch <- function(
115128
record_every = as.integer(record_every),
116129
window_size = as.integer(window_size),
117130
simplified = simplified,
131+
project_full_gradient = project_full_gradient
118132
#
119133
# armijo_beta = armijo_beta,
120134
# armijo_c = armijo_c,
121135
)
122136

137+
138+
df = as.data.frame(reticulate::py_to_r(res$traces))
139+
# Suppose your data frame is named df
140+
cols <- c("iter", "time", "fidelity", "orthogonality", "total_energy")
141+
# Find all iteration suffixes
142+
suffixes <- unique(gsub(".*\\.", "", grep("\\.", names(df), value = TRUE)))
143+
suffixes <- c("", sort(unique(suffixes))) # include first iteration (no suffix)
144+
145+
# Build rows for each suffix
146+
rows <- lapply(suffixes, function(suf) {
147+
postfix <- if (suf == "") "" else paste0(".", suf)
148+
subset <- df[ , paste0(cols, postfix), drop = FALSE]
149+
names(subset) <- cols
150+
subset
151+
})
152+
153+
# Bind all iterations together
154+
trace_df <- do.call(rbind, rows)
155+
rownames(trace_df) <- NULL
156+
157+
if (plot && !is.null(trace_df) && nrow(trace_df) > 0) {
158+
max_fid <- max(trace_df$fidelity, na.rm = TRUE)
159+
max_orth <- max(trace_df$orthogonality, na.rm = TRUE)
160+
ratio <- if (max_orth > 0) max_fid / max_orth else 1
161+
energy_plot <- ggplot2::ggplot(trace_df, ggplot2::aes(x = iter)) +
162+
ggplot2::geom_line(ggplot2::aes(y = fidelity, color = "Fidelity"), size = 1.2) +
163+
ggplot2::geom_point(ggplot2::aes(y = fidelity, color = "Fidelity"), size = 1.5, alpha = 0.7) +
164+
ggplot2::geom_line(ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"), size = 1.2) +
165+
ggplot2::geom_point(ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"), size = 1.5, alpha = 0.7) +
166+
ggplot2::scale_y_continuous(name = "Fidelity Energy",
167+
sec.axis = ggplot2::sec_axis(~ . / ratio, name = "Orthogonality Defect")) +
168+
ggplot2::scale_color_manual(values = c("Fidelity" = "#1f78b4", "Orthogonality" = "#33a02c")) +
169+
ggplot2::labs(title = paste("NSA-Flow Optimization Trace: ", retraction),
170+
subtitle = "Fidelity and Orthogonality Terms (Dual Scales)",
171+
x = "Iteration", color = "Term") +
172+
ggplot2::theme_minimal(base_size = 14) +
173+
ggplot2::theme(plot.title = ggplot2::element_text(face = "bold", hjust = 0.5),
174+
plot.subtitle = ggplot2::element_text(hjust = 0.5),
175+
legend.position = "top",
176+
panel.grid.major = ggplot2::element_line(color = "gray80"),
177+
panel.grid.minor = ggplot2::element_line(color = "gray90"),
178+
axis.title.y.left = ggplot2::element_text(color = "#1f78b4"),
179+
axis.text.y.left = ggplot2::element_text(color = "#1f78b4"),
180+
axis.title.y.right = ggplot2::element_text(color = "#33a02c"),
181+
axis.text.y.right = ggplot2::element_text(color = "#33a02c"))
182+
}
183+
Y = as.matrix(res$Y$detach()$numpy())
184+
rownames(Y) <- rownames(Y0)
185+
colnames(Y) <- colnames(Y0)
123186
list(
124-
Y = as.matrix(res$Y$detach()$numpy()),
187+
Y=Y,
125188
energy = reticulate::py_to_r(res$best_total_energy),
126-
iter = reticulate::py_to_r(res$final_iter)
189+
traces = trace_df,
190+
iter = reticulate::py_to_r(res$final_iter),
191+
plot = if (plot) energy_plot else NULL
127192
)
128193
}

R/zzz.R

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,6 @@
55

66

77
.onLoad <- function(libname, pkgname) {
8-
py_file <- system.file("python", "nsa_flow_py.py", package = pkgname)
9-
if (file.exists(py_file)) {
10-
assign("nsa_env", reticulate::source_python(py_file, convert = FALSE),
11-
envir = parent.env(environment()))
12-
} else {
13-
warning("Could not find nsa_flow_py.py in package")
14-
}
8+
invisible()
159
}
1610

man/nsa_flow_torch.Rd

Lines changed: 101 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)