Skip to content

Commit b5c1476

Browse files
committed
WIP: torch nsa flow
1 parent e3d843f commit b5c1476

11 files changed

+1755
-0
lines changed

R/nsa_flow_torch.R

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#' Run NSA-Flow via PyTorch backend
2+
#'
3+
#' @param Y A numeric matrix
4+
#' @param Xc A numeric matrix
5+
#' @param w_pca,lambda,lr,max_iter,retraction_type,armijo_beta,armijo_c,tol,verbose Control parameters
6+
#'
7+
#' @return A list with elements `Y`, `energy`, and `iter`
8+
#' @export
9+
nsa_flow_torch <- function(
10+
Y, Xc,
11+
w_pca = 1.0, lambda = 0.01,
12+
lr = 1e-2, max_iter = 100L,
13+
retraction_type = c("soft_polar", "polar", "none"),
14+
armijo_beta = 0.5, armijo_c = 1e-4,
15+
tol = 1e-6, verbose = FALSE
16+
) {
17+
retraction_type <- match.arg(retraction_type)
18+
19+
torch <- reticulate::import("torch", convert = FALSE)
20+
pynsa <- reticulate::import("nsa_flow", convert=FALSE )
21+
22+
23+
if (is.null(pynsa) ) {
24+
stop("Could not find Python function `nsa_flow_py` in the 'nsa_flow_py' module.")
25+
}
26+
27+
Y_torch <- torch$tensor(Y, dtype = torch$float64)
28+
Xc_torch <- torch$tensor(Xc, dtype = torch$float64)
29+
30+
res <- pynsa$nsa_flow_py(
31+
Y_torch, Xc_torch,
32+
w_pca = w_pca,
33+
lambda_ = lambda,
34+
lr = lr,
35+
max_iter = as.integer(max_iter),
36+
retraction_type = retraction_type,
37+
armijo_beta = armijo_beta,
38+
armijo_c = armijo_c,
39+
tol = tol,
40+
verbose = verbose
41+
)
42+
43+
list(
44+
Y = as.matrix(res$Y$numpy()),
45+
energy = reticulate::py_to_r(res$energy),
46+
iter = reticulate::py_to_r(res$iter)
47+
)
48+
}

man/apply_transform_matrix.Rd

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/estimate_learning_rate_nsa.Rd

Lines changed: 54 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nsa_flow_fa.Rd

Lines changed: 87 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nsa_flow_fa_diagram.Rd

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nsa_flow_pca.Rd

Lines changed: 75 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nsa_flow_retract_auto.Rd

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nsa_flow_torch.Rd

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)