@@ -533,3 +533,228 @@ nsa_flow_torch_ag <- function(
533533 plot = if (plot ) energy_plot else NULL
534534 )
535535}
536+
537+
538+
#' @title NSA-Flow Optimization via PyTorch AutoGrad (R wrapper)
#'
#' @description
#' Mirror of Python `nsa_flow_autograd()` but callable from R via reticulate.
#' Uses the Python implementation (`nsa_flow.nsa_flow_autograd`) and returns
#' R-friendly results (matrix, data.frame, ggplot).
#'
#' @details
#' All argument validation is done in R before Python is initialised, so
#' malformed input fails fast without importing `torch` or `nsa_flow`.
#' `Y0` (and `X0`, when given) are converted to torch tensors whose dtype is
#' `float32` when `precision == "float32"` and `float64` otherwise. Every
#' other argument is forwarded to the Python function under the same name.
#'
#' @param Y0 numeric matrix p x k initial guess
#' @param X0 numeric matrix p x k target, or `NULL` (forwarded to Python as
#'   `None`)
#' @param w numeric in \[0,1\] weighting fidelity vs orthogonality
#' @param retraction character retraction method; one of `"soft_polar"`,
#'   `"polar"`, `"none"`
#' @param max_iter integer max iterations
#' @param tol numeric convergence tolerance
#' @param verbose logical
#' @param seed integer random seed
#' @param apply_nonneg single logical or option string (e.g. `"softplus"`,
#'   `"none"`). `TRUE`/`FALSE` (also given as the strings `"TRUE"`/`"FALSE"`)
#'   are sent to Python as booleans; any other value is sent as-is.
#' @param optimizer character optimizer name (e.g. 'Adam','lars','sgdp')
#' @param initial_learning_rate `NULL` (sent as `None`), a single numeric, or
#'   a single strategy string; forwarded as `initial_learning_rate`
#' @param lr_strategy character; always forwarded unchanged to Python as
#'   `lr_strategy`
#' @param fidelity_type character
#'   ('basic','scale_invariant','symmetric','normalized')
#' @param orth_type character ('basic','scale_invariant')
#' @param record_every integer frequency of recording traces
#' @param window_size integer window for energy stability
#' @param plot logical produce ggplot (default FALSE)
#' @param precision 'float32' or 'float64'
#' @return A list of class `"nsa_flow_result"` with elements:
#'   `Y` (numeric matrix carrying the dimnames of `Y0` when the shapes match,
#'   or `NULL` if the Python result could not be converted),
#'   `traces` (data.frame when the Python traces could be tabulated, otherwise
#'   the converted object as returned by reticulate, or `NULL`),
#'   `final_iter`, `best_total_energy`, `best_Y_iteration`,
#'   `plot` (ggplot or `NULL`) and `settings`.
#' @export
nsa_flow_autograd <- function(
    Y0,
    X0 = NULL,
    w = 0.5,
    retraction = c("soft_polar", "polar", "none"),
    max_iter = 500L,
    tol = 1e-6,
    verbose = FALSE,
    seed = 42L,
    apply_nonneg = TRUE,
    optimizer = "Adam",
    initial_learning_rate = NULL,
    lr_strategy = "auto",
    fidelity_type = "scale_invariant",
    orth_type = "scale_invariant",
    record_every = 1L,
    window_size = 5L,
    plot = FALSE,
    precision = "float64"
) {
  # ---- Argument validation (pure R; runs before Python is touched) --------
  if (!is.matrix(Y0) || !is.numeric(Y0)) {
    stop("Y0 must be a numeric matrix.", call. = FALSE)
  }
  if (!is.null(X0) && !(is.matrix(X0) && is.numeric(X0))) {
    stop("X0 must be a numeric matrix or NULL.", call. = FALSE)
  }
  retraction <- match.arg(retraction)

  lr_is_scalar <- length(initial_learning_rate) == 1L &&
    (is.character(initial_learning_rate) || is.numeric(initial_learning_rate))
  if (!is.null(initial_learning_rate) && !lr_is_scalar) {
    stop(
      "initial_learning_rate must be NULL, a numeric, or a strategy string.",
      call. = FALSE
    )
  }
  if (length(apply_nonneg) != 1L) {
    stop(
      "apply_nonneg must be a single logical value or option string.",
      call. = FALSE
    )
  }

  # ---- Import Python modules ----------------------------------------------
  # reticulate::import() signals an error (it never returns NULL) when a
  # module is missing, so failures are caught here and re-signalled clearly.
  import_py <- function(module) {
    tryCatch(
      reticulate::import(module, convert = FALSE),
      error = function(e) {
        stop(
          "Could not import python package '", module, "': ",
          conditionMessage(e),
          call. = FALSE
        )
      }
    )
  }
  torch <- import_py("torch")
  pynsa <- import_py("nsa_flow")

  # ---- Convert inputs to Python objects -----------------------------------
  torch_dtype <- if (identical(precision, "float32")) {
    torch$float32
  } else {
    torch$float64
  }
  Y_torch <- torch$tensor(Y0, dtype = torch_dtype)
  X_torch <- if (is.null(X0)) {
    reticulate::r_to_py(NULL)
  } else {
    torch$tensor(X0, dtype = torch_dtype)
  }

  py_init_lr <- if (is.null(initial_learning_rate)) {
    reticulate::r_to_py(NULL)
  } else if (is.character(initial_learning_rate)) {
    # strategy string (e.g. "armijo", "auto") is passed through verbatim
    reticulate::r_to_py(initial_learning_rate)
  } else {
    reticulate::r_to_py(as.double(initial_learning_rate))
  }

  # TRUE/FALSE (logical or spelled as strings) become Python booleans; the
  # known option strings and anything else are forwarded unchanged.
  apply_nonneg_py <- switch(
    as.character(apply_nonneg),
    "TRUE" = reticulate::r_to_py(TRUE),
    "FALSE" = reticulate::r_to_py(FALSE),
    "softplus" = reticulate::r_to_py("softplus"),
    "none" = reticulate::r_to_py("none"),
    reticulate::r_to_py(apply_nonneg)
  )

  # Argument names mirror the Python signature.
  py_args <- list(
    Y0 = Y_torch,
    X0 = X_torch,
    w = as.double(w),
    retraction = retraction,
    max_iter = as.integer(max_iter),
    tol = as.double(tol),
    verbose = reticulate::r_to_py(verbose),
    seed = as.integer(seed),
    apply_nonneg = apply_nonneg_py,
    optimizer = optimizer,
    initial_learning_rate = py_init_lr,
    lr_strategy = lr_strategy,
    fidelity_type = fidelity_type,
    orth_type = orth_type,
    record_every = as.integer(record_every),
    window_size = as.integer(window_size),
    precision = precision
  )

  # ---- Call Python ---------------------------------------------------------
  res_py <- tryCatch(
    do.call(pynsa$nsa_flow_autograd, py_args),
    error = function(e) {
      stop(
        "Error calling Python nsa_flow_autograd():\n", conditionMessage(e),
        call. = FALSE
      )
    }
  )

  # ---- Convert Y -----------------------------------------------------------
  # The result is expected to be dict-like with keys "Y", "traces",
  # "final_iter", "best_total_energy", "best_Y_iteration", "settings".
  # Because the module was imported with convert = FALSE, every value must be
  # converted explicitly with py_to_r(); as.matrix() on the raw Python object
  # is not a conversion.
  Y_converted <- tryCatch(
    {
      y_py <- res_py$Y
      if (reticulate::py_has_attr(y_py, "detach")) {
        # tensor-like: detach from the graph, move to CPU, go through NumPy
        reticulate::py_to_r(y_py$detach()$cpu()$numpy())
      } else {
        reticulate::py_to_r(y_py)
      }
    },
    error = function(e) {
      warning(
        "Could not convert Python result 'Y' to R: ", conditionMessage(e),
        call. = FALSE
      )
      NULL
    }
  )
  Y_out <- NULL
  if (is.numeric(Y_converted) || is.logical(Y_converted)) {
    Y_out <- as.matrix(Y_converted)
  } else if (!is.null(Y_converted)) {
    warning(
      "Python result 'Y' is not numeric after conversion; returning NULL.",
      call. = FALSE
    )
  }
  # Carry over the input's dimnames only when the shapes agree.
  if (!is.null(Y_out) && identical(dim(Y_out), dim(Y0))) {
    dimnames(Y_out) <- list(rownames(Y0), colnames(Y0))
  }

  # ---- Convert traces ------------------------------------------------------
  traces_raw <- tryCatch(
    reticulate::py_to_r(res_py$traces),
    error = function(e) {
      warning(
        "Could not convert Python result 'traces' to R: ",
        conditionMessage(e),
        call. = FALSE
      )
      NULL
    }
  )
  record_to_df <- function(rec) {
    # Python None arrives as NULL, which as.data.frame() cannot hold
    rec[vapply(rec, is.null, logical(1L))] <- NA
    as.data.frame(rec, stringsAsFactors = FALSE)
  }
  traces_df <- traces_raw
  if (is.list(traces_raw) && !is.data.frame(traces_raw) &&
      length(traces_raw) > 0L) {
    is_record_list <- is.null(names(traces_raw)) && is.list(traces_raw[[1L]])
    traces_df <- tryCatch(
      if (is_record_list) {
        # list of per-iteration dicts -> one row per record
        do.call(rbind, lapply(traces_raw, record_to_df))
      } else {
        as.data.frame(traces_raw, stringsAsFactors = FALSE)
      },
      error = function(e) {
        warning(
          "Could not tabulate 'traces'; returning it unconverted: ",
          conditionMessage(e),
          call. = FALSE
        )
        traces_raw
      }
    )
  }
  if (is.data.frame(traces_df)) {
    rownames(traces_df) <- NULL
  }

  # ---- Scalar summaries ----------------------------------------------------
  final_iter <- tryCatch(
    reticulate::py_to_r(res_py$final_iter),
    error = function(e) NA
  )
  best_total_energy <- tryCatch(
    reticulate::py_to_r(res_py$best_total_energy),
    error = function(e) NA
  )
  best_Y_iteration <- tryCatch(
    reticulate::py_to_r(res_py$best_Y_iteration),
    error = function(e) NA
  )
  settings <- tryCatch(
    reticulate::py_to_r(res_py$settings),
    error = function(e) NULL
  )

  # ---- Optional energy trace plot -----------------------------------------
  energy_plot <- NULL
  if (isTRUE(plot) && is.data.frame(traces_df) && nrow(traces_df) > 0L) {
    # Work on a copy so the returned traces do not depend on `plot`.
    plot_df <- traces_df
    if (!all(c("fidelity", "orthogonality") %in% names(plot_df))) {
      # fall back to the first column whose name looks like the wanted term
      fid_hits <- grep("fid", names(plot_df), value = TRUE, ignore.case = TRUE)
      orth_hits <- grep(
        "orth", names(plot_df),
        value = TRUE, ignore.case = TRUE
      )
      if (length(fid_hits) >= 1L) {
        names(plot_df)[names(plot_df) == fid_hits[1L]] <- "fidelity"
      }
      if (length(orth_hits) >= 1L) {
        names(plot_df)[names(plot_df) == orth_hits[1L]] <- "orthogonality"
      }
    }
    if (!("iter" %in% names(plot_df))) {
      # no recorded iteration index: use the row order as the x axis
      plot_df$iter <- seq_len(nrow(plot_df))
    }

    if (all(c("fidelity", "orthogonality") %in% names(plot_df))) {
      max_fid <- max(plot_df$fidelity, na.rm = TRUE)
      max_orth <- max(plot_df$orthogonality, na.rm = TRUE)
      # Scale factor that puts orthogonality on the fidelity axis. It must be
      # finite and positive or the secondary axis (which divides by it)
      # degenerates.
      ratio <- if (is.finite(max_orth) && max_orth > 0) {
        max_fid / max_orth
      } else {
        1
      }
      if (!is.finite(ratio) || ratio <= 0) {
        ratio <- 1
      }

      energy_plot <- ggplot2::ggplot(plot_df, ggplot2::aes(x = iter)) +
        ggplot2::geom_line(
          ggplot2::aes(y = fidelity, color = "Fidelity"),
          size = 1.1
        ) +
        ggplot2::geom_point(
          ggplot2::aes(y = fidelity, color = "Fidelity"),
          size = 1.2, alpha = 0.7
        ) +
        ggplot2::geom_line(
          ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"),
          size = 1.1
        ) +
        ggplot2::geom_point(
          ggplot2::aes(y = orthogonality * ratio, color = "Orthogonality"),
          size = 1.2, alpha = 0.7
        ) +
        ggplot2::scale_y_continuous(
          name = "Fidelity Energy",
          sec.axis = ggplot2::sec_axis(
            ~ . / ratio,
            name = "Orthogonality Defect"
          )
        ) +
        ggplot2::scale_color_manual(
          values = c("Fidelity" = "#1f78b4", "Orthogonality" = "#33a02c")
        ) +
        ggplot2::labs(
          title = paste("NSA-Flow Optimization Trace:", retraction),
          subtitle = paste0(
            "fidelity_type=", fidelity_type, ", orth_type=", orth_type
          ),
          x = "Iteration", color = "Term"
        ) +
        ggplot2::theme_minimal(base_size = 13) +
        ggplot2::theme(
          plot.title = ggplot2::element_text(face = "bold", hjust = 0.5),
          legend.position = "top"
        )
    }
  }

  # ---- Assemble result -----------------------------------------------------
  # No rescaling is applied here: outputs are returned exactly as converted
  # from the Python result.
  out <- list(
    Y = Y_out,
    traces = traces_df,
    final_iter = final_iter,
    best_total_energy = best_total_energy,
    best_Y_iteration = best_Y_iteration,
    plot = energy_plot,
    settings = settings
  )
  class(out) <- c("nsa_flow_result", class(out))
  out
}
0 commit comments