|
34 | 34 | #' @param simplified Logical, if \code{TRUE}, uses the simplified objective |
35 | 35 | #' \deqn{\min_U (1 - w) \frac{1}{2} ||U - Z||_F^2 + w \frac{1}{2} ||U^\top U - I_k||_F^2}. |
36 | 36 | #' If \code{FALSE}, uses the invariant defect objective. Default is \code{FALSE}. |
| 37 | +#' @param project_full_gradient Logical, if \code{TRUE}, projects the full gradient instead |
| 38 | +#' of just the orthogonal component. Default is \code{FALSE}. |
37 | 39 | #' @param plot Logical, if \code{TRUE}, generates a ggplot of fidelity and orthogonality |
38 | 40 | #' traces with dual axes. Default is \code{FALSE}. |
39 | 41 | #' |
@@ -81,12 +83,23 @@ nsa_flow_torch <- function( |
81 | 83 | initial_learning_rate = 'default', |
82 | 84 | record_every = 1, window_size = 5, c1_armijo=1e-6, |
83 | 85 | simplified = FALSE, |
| 86 | + project_full_gradient = FALSE, |
84 | 87 | plot = FALSE |
85 | 88 | ) { |
86 | 89 | if (!is.matrix(Y0)) { |
87 | 90 | stop("Y0 must be a numeric matrix.") |
88 | 91 | } |
89 | 92 |
|
| 93 | + if (is.null(X0)) { |
| 94 | + if ( apply_nonneg ) X0 <- pmax(Y0, 0) else X0 = Y0 |
| 95 | + perturb_scale <- sqrt(sum(Y0^2)) / sqrt(length(Y0)) * 0.05 |
| 96 | + Y0 <- Y0 + matrix(rnorm(nrow(Y0) * ncol(Y0), sd = perturb_scale), nrow(Y0), ncol(Y0)) |
| 97 | + if (verbose) cat("Added perturbation to Y0\n") |
| 98 | + } else { |
| 99 | + if ( apply_nonneg ) X0 <- pmax(X0, 0) |
| 100 | + if (nrow(X0) != nrow(Y0) || ncol(X0) != ncol(Y0)) stop("X0 must have same dimensions as Y0") |
| 101 | + } |
| 102 | + |
90 | 103 | retraction_type <- match.arg(retraction) |
91 | 104 |
|
92 | 105 | torch <- reticulate::import("torch", convert = FALSE) |
@@ -115,14 +128,66 @@ nsa_flow_torch <- function( |
115 | 128 | record_every = as.integer(record_every), |
116 | 129 | window_size = as.integer(window_size), |
117 | 130 | simplified = simplified, |
| 131 | + project_full_gradient = project_full_gradient |
118 | 132 | # |
119 | 133 | # armijo_beta = armijo_beta, |
120 | 134 | # armijo_c = armijo_c, |
121 | 135 | ) |
122 | 136 |
|
| 137 | + |
| 138 | + df = as.data.frame(reticulate::py_to_r(res$traces)) |
| 139 | + # Suppose your data frame is named df |
| 140 | + cols <- c("iter", "time", "fidelity", "orthogonality", "total_energy") |
| 141 | + # Find all iteration suffixes |
| 142 | + suffixes <- unique(gsub(".*\\.", "", grep("\\.", names(df), value = TRUE))) |
| 143 | + suffixes <- c("", sort(unique(suffixes))) # include first iteration (no suffix) |
| 144 | + |
| 145 | + # Build rows for each suffix |
| 146 | + rows <- lapply(suffixes, function(suf) { |
| 147 | + postfix <- if (suf == "") "" else paste0(".", suf) |
| 148 | + subset <- df[ , paste0(cols, postfix), drop = FALSE] |
| 149 | + names(subset) <- cols |
| 150 | + subset |
| 151 | + }) |
| 152 | + |
| 153 | + # Bind all iterations together |
| 154 | + trace_df <- do.call(rbind, rows) |
| 155 | + rownames(trace_df) <- NULL |
| 156 | + |
| 157 | + if (plot && !is.null(trace_df) && nrow(trace_df) > 0) { |
| 158 | + max_fid <- max(trace_df$fidelity, na.rm = TRUE) |
| 159 | + max_orth <- max(trace_df$orthogonality, na.rm = TRUE) |
| 160 | + ratio <- if (max_orth > 0) max_fid / max_orth else 1 |
| 161 | + energy_plot <- ggplot2::ggplot(trace_df, ggplot2::aes(x = iter)) + |
| 162 | + ggplot2::geom_line(ggplot2::aes(y = fidelity, color = "Fidelity"), size = 1.2) + |
| 163 | + ggplot2::geom_point(ggplot2::aes(y = fidelity, color = "Fidelity"), size = 1.5, alpha = 0.7) + |
| 164 | + ggplot2::geom_line(ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"), size = 1.2) + |
| 165 | + ggplot2::geom_point(ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"), size = 1.5, alpha = 0.7) + |
| 166 | + ggplot2::scale_y_continuous(name = "Fidelity Energy", |
| 167 | + sec.axis = ggplot2::sec_axis(~ . / ratio, name = "Orthogonality Defect")) + |
| 168 | + ggplot2::scale_color_manual(values = c("Fidelity" = "#1f78b4", "Orthogonality" = "#33a02c")) + |
| 169 | + ggplot2::labs(title = paste("NSA-Flow Optimization Trace: ", retraction), |
| 170 | + subtitle = "Fidelity and Orthogonality Terms (Dual Scales)", |
| 171 | + x = "Iteration", color = "Term") + |
| 172 | + ggplot2::theme_minimal(base_size = 14) + |
| 173 | + ggplot2::theme(plot.title = ggplot2::element_text(face = "bold", hjust = 0.5), |
| 174 | + plot.subtitle = ggplot2::element_text(hjust = 0.5), |
| 175 | + legend.position = "top", |
| 176 | + panel.grid.major = ggplot2::element_line(color = "gray80"), |
| 177 | + panel.grid.minor = ggplot2::element_line(color = "gray90"), |
| 178 | + axis.title.y.left = ggplot2::element_text(color = "#1f78b4"), |
| 179 | + axis.text.y.left = ggplot2::element_text(color = "#1f78b4"), |
| 180 | + axis.title.y.right = ggplot2::element_text(color = "#33a02c"), |
| 181 | + axis.text.y.right = ggplot2::element_text(color = "#33a02c")) |
| 182 | + } |
| 183 | + Y = as.matrix(res$Y$detach()$numpy()) |
| 184 | + rownames(Y) <- rownames(Y0) |
| 185 | + colnames(Y) <- colnames(Y0) |
123 | 186 | list( |
124 | | - Y = as.matrix(res$Y$detach()$numpy()), |
| 187 | + Y=Y, |
125 | 188 | energy = reticulate::py_to_r(res$best_total_energy), |
126 | | - iter = reticulate::py_to_r(res$final_iter) |
| 189 | + traces = trace_df, |
| 190 | + iter = reticulate::py_to_r(res$final_iter), |
| 191 | + plot = if (plot) energy_plot else NULL |
127 | 192 | ) |
128 | 193 | } |
0 commit comments