# test_Hessian.py
#
# Script comparing the materialized Sinkhorn Hessian applied to a random matrix
# with the matrix-free HessianA / HessianAPrecond products.
import jax
import jax.numpy as jnp
from jax import config
from jaxtyping import Array, Float
config.update("jax_enable_x64", True)
import lineax as lx
import time
import timeit
from jax.example_libraries import optimizers as jax_opt
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
import numpy as np
import SinkhornHessian
import util
import matplotlib
import matplotlib.pyplot as plt
import mpl_toolkits.axes_grid1
#%% problem setup
# Purpose of this script: solve an entropic optimal transport problem with OTT's
# Sinkhorn solver, apply the materialized Hessian to a random matrix A, and
# compare with the two matrix-free products from the SinkhornHessian module.
n = 5000
m = 10000
epsilon = 0.0025  # passed as `epsilon` to the PointCloud geometry
dim = 5  # third argument of util.sample_points_uniform; also the column count of A
threshold = 0.01 / (n**0.33)  # Sinkhorn convergence threshold, shrinking with n
tau2 = 1e-5  # forwarded to HessianA / HessianAPrecond; meaning not visible here -- TODO confirm
# Named `num_iter` instead of `iter` so the builtin iter() is not shadowed.
# Presumably an iteration budget for the matrix-free solvers -- TODO confirm.
num_iter = 10

# Only (mu, x) from the first call and (nv, y) from the second call are used;
# the other outputs are discarded. The trailing integers (4 and 10) look like
# seeds -- verify against util.sample_points_uniform.
mu, _, x, _ = util.sample_points_uniform(n, n, dim, 4)
_, nv, _, y = util.sample_points_uniform(m, m, dim, 10)

#%% solve the optimal transport problem
geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=16)
prob = linear_problem.LinearProblem(geom, a=mu, b=nv)
solver = sinkhorn.Sinkhorn(
    threshold=threshold, use_danskin=False, max_iterations=200000
)
out = solver(prob)

#%% generate random matrix A
# Seeded generator so the errors printed below are reproducible between runs
# (the original draw used the unseeded global NumPy state).
rng = np.random.default_rng(0)
A = rng.standard_normal((n, dim))

#%% materialize the hessian
svd_thr = 1e-10
SH = SinkhornHessian.SinkhornHessian(svd_thr)
T = SH.compute_hessian(out)
# Contract axes 2 and 3 of T with axes 0 and 1 of A (the Hessian applied to A).
result1 = jnp.tensordot(T, A, axes=((2, 3), (0, 1)))

#%% compute the hessian dot A without materializing the hessian
# without preconditioning
result2 = SinkhornHessian.HessianA(A, out, tau2, num_iter)
# with preconditioning
result3 = SinkhornHessian.HessianAPrecond(A, out, tau2, num_iter)

#%% compare the results
# Report the largest entrywise deviation from the materialized-Hessian result.
print("The max error between the true Hessian and without preconditioning is:")
print(jnp.max(jnp.abs(result1 - result2)))
print("The max error between the true Hessian and with preconditioning is:")
print(jnp.max(jnp.abs(result1 - result3)))