From 3d006ff67269eeccbdf49dc04002ac433cc550ec Mon Sep 17 00:00:00 2001 From: theorashid <theoaorashid@gmail.com> Date: Mon, 19 Aug 2024 18:54:08 +0100 Subject: [PATCH 1/7] feat: Add fixed point solver for optimization, with tests comparing to jax Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com> Co-authored-by: aseyboldt <aseyboldt@users.noreply.github.com> Co-authored-by: Rob Zinkov <zaxtax@users.noreply.github.com> Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pytensor/optimise/__init__.py | 0 pytensor/optimise/fixed_point.py | 134 +++++++++++++++++++++++++++++ tests/optimise/__init__.py | 0 tests/optimise/test_fixed_point.py | 84 ++++++++++++++++++ 4 files changed, 218 insertions(+) create mode 100644 pytensor/optimise/__init__.py create mode 100644 pytensor/optimise/fixed_point.py create mode 100644 tests/optimise/__init__.py create mode 100644 tests/optimise/test_fixed_point.py diff --git a/pytensor/optimise/__init__.py b/pytensor/optimise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py new file mode 100644 index 0000000000..e9732e1be6 --- /dev/null +++ b/pytensor/optimise/fixed_point.py @@ -0,0 +1,134 @@ +from functools import partial + +import pytensor +import pytensor.tensor as pt +from pytensor.scan.utils import until + + +def _check_convergence(f_x, tol): + # TODO What convergence criterion? Norm of grad etc... + converged = pt.lt(pt.linalg.norm(f_x, ord=1), tol) + return converged + + +def fwd_solver(x_prev, *args, func, tol): + x = func(x_prev, *args) + is_converged = _check_convergence(x - x_prev, tol) + return x, is_converged + + +def newton_solver(x_prev, *args, func, tol): + f_root = func(x_prev, *args) - x_prev + jac = pt.jacobian(f_root, x_prev) + + # TODO It would be nice to return the factored matrix for the pullback + # TODO Handle errors of the factorization + # 1D: x - f(x) / f'(x) + # general: x - J^-1 f(x) + + # grad = pt.linalg.solve(jac, f_root, assume_a="sym") + grad = pt.linalg.solve(jac, f_root) + x = x_prev - grad + + is_converged = _check_convergence(x - x_prev, tol) + + return x, is_converged + + +def fixed_point_solver( + f: callable, + solver: callable, + x0: pt.TensorVariable, + *args: tuple[pt.Variable, ...], + max_iter: int = 1000, + tol: float = 1e-5, +): + args = [pt.as_tensor(arg) for arg in args] + print(len(args)) + + def _scan_step(x, *args, func, solver, tol): + print(x.type) + x, is_converged = solver(x, *args, func=func, tol=tol) + print(x.type) + return x, until(is_converged) + + partial_step = partial( + _scan_step, + func=f, + solver=solver, + tol=tol, + ) + + x_sequence, updates = pytensor.scan( + partial_step, + outputs_info=[x0], + non_sequences=list(args), + n_steps=max_iter, + strict=True, + ) + + assert not updates + + x = x_sequence[-1] + return x, x_sequence.shape[0] + + +# %% +# x_star, n_steps = fixed_point_solver( +# g, +# fwd_solver, +# pt.zeros_like(b), +# W, b, +# ) +# print(x_star.eval(), n_steps.eval()) + +# %% +# x_star, n_steps = fixed_point_solver( +# g, +# newton_solver, +# pt.zeros_like(b), +# W, b, +# max_n_steps=10, +# ) +# print(x_star.eval(), n_steps.eval()) + + +# def _newton_solver(x_prev, *args, func, tol): +# f_root = lambda x: func(x) - x +# g = lambda x: x - pt.linalg.solve(pt.jacobian(f_root)(x), f_root(x)) +# return fwd_solver(g, x) + +# # %% +# def jax_newton_solver(f, z_init): +# f_root = lambda z: f(z) - z +# grad = 
jax.numpy.linalg.solve(jax.jacobian(f_root)(z_init), f_root(z_init)) +# # print(np.linalg.solve(grad, f_root_z)) +# # print(sp.linalg.solve(grad, f_root_z)) +# # print(jax.numpy.linalg.solve(grad, f_root_z)) +# # print(pt.linalg.solve(grad, f_root_z).eval()) +# g = lambda z: z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z)) +# return jax_fwd_solver(g, z_init) +# # return grad + +# def jax_fwd_solver(f, z_init): +# z_prev, z = z_init, f(z_init) +# i = 1 +# while jax.numpy.linalg.norm(z_prev - z) > 1e-5: +# z_prev, z = z, f(z) +# i += 1 +# print(i) +# return z + +# def _jax_g(x, W, b): +# return jax.numpy.tanh(jax.numpy.dot(W, x) + b) + + +# jax_g = partial(_jax_g, W=W, b=b) + +# print(jax_newton_solver(jax_g, jax.numpy.zeros_like(b))) + +# Array([-0.02879991, -0.8708013 , -1.4001148 , -0.1013868 , -0.641474 , +# -0.7552165 , 0.62554246, 0.9438805 , -0.05192749, 1.430574 ], dtype=float32) + + +# %% diff --git a/tests/optimise/__init__.py b/tests/optimise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/optimise/test_fixed_point.py b/tests/optimise/test_fixed_point.py new file mode 100644 index 0000000000..90860ae15b --- /dev/null +++ b/tests/optimise/test_fixed_point.py @@ -0,0 +1,84 @@ +import functools as ft + +import jax +import numpy as np +from numpy.testing import assert_array_almost_equal + +import pytensor.tensor as pt +from pytensor.optimise.fixed_point import fixed_point_solver, fwd_solver, newton_solver + + +jax.config.update("jax_enable_x64", True) + + +def jax_newton_solver(f, z_init): + def f_root(z): + return f(z) - z + + def g(z): + return z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z)) + + return jax_fwd_solver(g, z_init) + + +def jax_fwd_solver(f, z_init, tol=1e-5): + z_prev, z = z_init, f(z_init) + while jax.numpy.linalg.norm(z_prev - z) > tol: + z_prev, z = z, f(z) + return z + + +def test_fixed_point_forward(): + def g(x, W, b): + return pt.tanh(pt.dot(W, x) + b) + + def _jax_g(x, W, b): + return jax.numpy.tanh(jax.numpy.dot(W, x) + b) + + ndim = 10 + W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jax.numpy.sqrt(ndim) + b = jax.random.normal(jax.random.PRNGKey(1), (ndim,)) + + W, b = np.asarray(W), np.asarray(b) + + jax_g = ft.partial(_jax_g, W=W, b=b) + + jax_solution = jax_fwd_solver(jax_g, jax.numpy.zeros_like(b)) + pytensor_solution, _ = fixed_point_solver( + g, + fwd_solver, + pt.zeros_like(b), + W, + b, + ) + assert_array_almost_equal(jax_solution, pytensor_solution.eval(), decimal=5) + + +def test_fixed_point_newton(): + def g(x, W, b): + return pt.tanh(pt.dot(W, x) + b) + + def _jax_g(x, W, b): + return jax.numpy.tanh(jax.numpy.dot(W, x) + b) + + ndim = 10 + W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jax.numpy.sqrt(ndim) + b = jax.random.normal(jax.random.PRNGKey(1), (ndim,)) + + W, b = np.asarray(W), np.asarray(b) + + jax_g = ft.partial(_jax_g, W=W, b=b) + + jax_solution = jax_newton_solver(jax_g, jax.numpy.zeros_like(b)) + pytensor_solution, _ = fixed_point_solver( + g, + newton_solver, + pt.zeros_like(b), + W, + b, + ) + assert_array_almost_equal(jax_solution, pytensor_solution.eval(), decimal=5) + + +# TODO: test the grad is the same as naive grad from propagating through each step of the solver (`pt.grad`) +# and adjoint implicit function theorem rewritten grad From ff5c32f46ce3ad4fd1a463d7965de48f395eeb91 Mon Sep 17 00:00:00 2001 From: theorashid <theoaorashid@gmail.com> Date: Mon, 19 Aug 2024 18:58:23 +0100 Subject: [PATCH 2/7] Further information on where test is 
from --- tests/optimise/test_fixed_point.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/optimise/test_fixed_point.py b/tests/optimise/test_fixed_point.py index 90860ae15b..fc1bf3af34 100644 --- a/tests/optimise/test_fixed_point.py +++ b/tests/optimise/test_fixed_point.py @@ -29,6 +29,8 @@ def jax_fwd_solver(f, z_init, tol=1e-5): def test_fixed_point_forward(): + """Test taken from the [Deep Implicit Layers workshop](https://implicit-layers-tutorial.org/implicit_functions/).""" + def g(x, W, b): return pt.tanh(pt.dot(W, x) + b) @@ -82,3 +84,5 @@ def _jax_g(x, W, b): # TODO: test the grad is the same as naive grad from propagating through each step of the solver (`pt.grad`) # and adjoint implicit function theorem rewritten grad +# see the [notes](https://theorashid.github.io/notes/fixed-point-iteration +# and the [Deep Implicit Layers workshop](https://implicit-layers-tutorial.org/implicit_functions/) From 01858c41a78ca8a49b609eb1867f6896172bb0a2 Mon Sep 17 00:00:00 2001 From: theorashid <theoaorashid@gmail.com> Date: Mon, 19 Aug 2024 18:59:15 +0100 Subject: [PATCH 3/7] remove commented test case --- pytensor/optimise/fixed_point.py | 63 +------------------------------- 1 file changed, 1 insertion(+), 62 deletions(-) diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py index e9732e1be6..5b00b2f2c3 100644 --- a/pytensor/optimise/fixed_point.py +++ b/pytensor/optimise/fixed_point.py @@ -26,7 +26,7 @@ def newton_solver(x_prev, *args, func, tol): # 1D: x - f(x) / f'(x) # general: x - J^-1 f(x) - # grad = pt.linalg.solve(jac, f_root, assume_a="sym") + # TODO: consider `grad = pt.linalg.solve(jac, f_root, assume_a="sym")`` grad = pt.linalg.solve(jac, f_root) x = x_prev - grad @@ -71,64 +71,3 @@ def _scan_step(x, *args, func, solver, tol): x = x_sequence[-1] return x, x_sequence.shape[0] - - -# %% -# x_star, n_steps = fixed_point_solver( -# g, -# fwd_solver, -# pt.zeros_like(b), -# W, b, -# ) -# print(x_star.eval(), n_steps.eval()) - -# %% -# x_star, n_steps = fixed_point_solver( -# g, -# newton_solver, -# pt.zeros_like(b), -# W, b, -# max_n_steps=10, -# ) -# print(x_star.eval(), n_steps.eval()) - - -# def _newton_solver(x_prev, *args, func, tol): -# f_root = lambda x: func(x) - x -# g = lambda x: x - pt.linalg.solve(pt.jacobian(f_root)(x), f_root(x)) -# return fwd_solver(g, x) - -# # %% -# def jax_newton_solver(f, z_init): -# f_root = lambda z: f(z) - z -# grad = jax.numpy.linalg.solve(jax.jacobian(f_root)(z_init), f_root(z_init)) -# # print(np.linalg.solve(grad, f_root_z)) -# # print(sp.linalg.solve(grad, f_root_z)) -# # print(jax.numpy.linalg.solve(grad, f_root_z)) -# # print(pt.linalg.solve(grad, f_root_z).eval()) -# g = lambda z: z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z)) -# return jax_fwd_solver(g, z_init) -# # return grad - -# def jax_fwd_solver(f, z_init): -# z_prev, z = z_init, f(z_init) -# i = 1 -# while jax.numpy.linalg.norm(z_prev - z) > 1e-5: -# z_prev, z = z, f(z) -# i += 1 -# print(i) -# return z - -# def _jax_g(x, W, b): -# return jax.numpy.tanh(jax.numpy.dot(W, x) + b) - - -# jax_g = partial(_jax_g, W=W, b=b) - -# print(jax_newton_solver(jax_g, jax.numpy.zeros_like(b))) - -# Array([-0.02879991, -0.8708013 , -1.4001148 , -0.1013868 , -0.641474 , -# -0.7552165 , 0.62554246, 0.9438805 , -0.05192749, 1.430574 ], dtype=float32) - - -# %% From c2f2fce112cd1eb373317b7639b299cc5d09c86f Mon Sep 17 00:00:00 2001 From: theorashid <theoaorashid@gmail.com> Date: Mon, 19 Aug 2024 19:04:55 +0100 Subject: [PATCH 4/7] remove debug 
print statements from scan --- pytensor/optimise/fixed_point.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py index 5b00b2f2c3..2befb8c502 100644 --- a/pytensor/optimise/fixed_point.py +++ b/pytensor/optimise/fixed_point.py @@ -44,12 +44,9 @@ def fixed_point_solver( tol: float = 1e-5, ): args = [pt.as_tensor(arg) for arg in args] - print(len(args)) def _scan_step(x, *args, func, solver, tol): - print(x.type) x, is_converged = solver(x, *args, func=func, tol=tol) - print(x.type) return x, until(is_converged) partial_step = partial( @@ -59,7 +56,7 @@ def _scan_step(x, *args, func, solver, tol): tol=tol, ) - x_sequence, updates = pytensor.scan( + outputs, updates = pytensor.scan( partial_step, outputs_info=[x0], non_sequences=list(args), @@ -67,7 +64,7 @@ def _scan_step(x, *args, func, solver, tol): strict=True, ) + x_trace = outputs assert not updates - x = x_sequence[-1] - return x, x_sequence.shape[0] + return x_trace[-1], x_trace.shape[0] From 0121ec3e315608f240d4c0ce13246a0265e96af4 Mon Sep 17 00:00:00 2001 From: theorashid <theoaorashid@gmail.com> Date: Mon, 19 Aug 2024 19:11:28 +0100 Subject: [PATCH 5/7] use trace to work out n_steps (see https://github.com/aseyboldt/pytensor/blob/newton/experiment-newton.ipynb) --- pytensor/optimise/fixed_point.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py index 2befb8c502..736fe702d1 100644 --- a/pytensor/optimise/fixed_point.py +++ b/pytensor/optimise/fixed_point.py @@ -45,9 +45,9 @@ def fixed_point_solver( ): args = [pt.as_tensor(arg) for arg in args] - def _scan_step(x, *args, func, solver, tol): + def _scan_step(x, n_steps, *args, func, solver, tol): x, is_converged = solver(x, *args, func=func, tol=tol) - return x, until(is_converged) + return (x, n_steps + 1), until(is_converged) partial_step = partial( _scan_step, @@ -58,13 +58,13 @@ def _scan_step(x, *args, func, solver, tol): outputs, updates = pytensor.scan( partial_step, - outputs_info=[x0], + outputs_info=[x0, pt.constant(0, dtype="int64")], non_sequences=list(args), n_steps=max_iter, strict=True, ) - x_trace = outputs + x_trace, n_steps_trace = outputs assert not updates - return x_trace[-1], x_trace.shape[0] + return x_trace[-1], n_steps_trace[-1] From a49ade3ec4b27672066cd5c03e9527fccc775726 Mon Sep 17 00:00:00 2001 From: theorashid <theoaorashid@gmail.com> Date: Mon, 19 Aug 2024 19:14:49 +0100 Subject: [PATCH 6/7] add TODO refactor newton solver to use `fwd_solver` --- pytensor/optimise/fixed_point.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py index 736fe702d1..d23a4d60d2 100644 --- a/pytensor/optimise/fixed_point.py +++ b/pytensor/optimise/fixed_point.py @@ -30,6 +30,7 @@ def newton_solver(x_prev, *args, func, tol): grad = pt.linalg.solve(jac, f_root) x = x_prev - grad + # TODO: consider if this can all be done as a single call to `fwd_solver` as in the jax test case is_converged = _check_convergence(x - x_prev, tol) return x, is_converged From 57aaf8c1282e1f33d6256b123fb23085f2c1300a Mon Sep 17 00:00:00 2001 From: theorashid <theoaorashid@gmail.com> Date: Wed, 11 Sep 2024 13:29:52 +0100 Subject: [PATCH 7/7] add notes so hackers can see --- pytensor/optimise/fixed_point.py | 5 +- tests/optimise/test_fixed_point.py | 95 ++++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 21 deletions(-) diff --git 
a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py index d23a4d60d2..370e04ad83 100644 --- a/pytensor/optimise/fixed_point.py +++ b/pytensor/optimise/fixed_point.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from functools import partial import pytensor @@ -37,8 +38,8 @@ def newton_solver(x_prev, *args, func, tol): def fixed_point_solver( - f: callable, - solver: callable, + f: Callable, + solver: Callable, x0: pt.TensorVariable, *args: tuple[pt.Variable, ...], max_iter: int = 1000, diff --git a/tests/optimise/test_fixed_point.py b/tests/optimise/test_fixed_point.py index fc1bf3af34..2e9cd428ee 100644 --- a/tests/optimise/test_fixed_point.py +++ b/tests/optimise/test_fixed_point.py @@ -1,5 +1,3 @@ -import functools as ft - import jax import numpy as np from numpy.testing import assert_array_almost_equal @@ -11,21 +9,26 @@ jax.config.update("jax_enable_x64", True) -def jax_newton_solver(f, z_init): - def f_root(z): - return f(z) - z +def jax_newton_solver(f, x0): + def f_root(x): + return f(x) - x + + def g(x): + return x - jax.numpy.linalg.solve(jax.jacobian(f_root)(x), f_root(x)) + + return jax_fwd_solver(g, x0) - def g(z): - return z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z)) - return jax_fwd_solver(g, z_init) +def jax_fwd_solver(f, x0, tol=1e-5): + x_prev, x = x0, f(x0) + while jax.numpy.linalg.norm(x_prev - x) > tol: + x_prev, x = x, f(x) + return x -def jax_fwd_solver(f, z_init, tol=1e-5): - z_prev, z = z_init, f(z_init) - while jax.numpy.linalg.norm(z_prev - z) > tol: - z_prev, z = z, f(z) - return z +def jax_fixed_point_solver(solver, f, params, x0, **solver_kwargs): + x_star = solver(lambda x: f(x, *params), x0=x0, **solver_kwargs) + return x_star def test_fixed_point_forward(): @@ -34,7 +37,7 @@ def test_fixed_point_forward(): def g(x, W, b): return pt.tanh(pt.dot(W, x) + b) - def _jax_g(x, W, b): + def jax_g(x, W, b): return jax.numpy.tanh(jax.numpy.dot(W, x) + b) ndim = 10 @@ -43,9 +46,13 @@ def _jax_g(x, W, b): W, b = np.asarray(W), np.asarray(b) - jax_g = ft.partial(_jax_g, W=W, b=b) + jax_solution = jax_fixed_point_solver( + jax_fwd_solver, + jax_g, + (W, b), + x0=jax.numpy.zeros_like(b), + ) - jax_solution = jax_fwd_solver(jax_g, jax.numpy.zeros_like(b)) pytensor_solution, _ = fixed_point_solver( g, fwd_solver, @@ -60,7 +67,7 @@ def test_fixed_point_newton(): def g(x, W, b): return pt.tanh(pt.dot(W, x) + b) - def _jax_g(x, W, b): + def jax_g(x, W, b): return jax.numpy.tanh(jax.numpy.dot(W, x) + b) ndim = 10 @@ -69,9 +76,13 @@ def _jax_g(x, W, b): W, b = np.asarray(W), np.asarray(b) - jax_g = ft.partial(_jax_g, W=W, b=b) + jax_solution = jax_fixed_point_solver( + jax_newton_solver, + jax_g, + (W, b), + x0=jax.numpy.zeros_like(b), + ) - jax_solution = jax_newton_solver(jax_g, jax.numpy.zeros_like(b)) pytensor_solution, _ = fixed_point_solver( g, newton_solver, @@ -86,3 +97,49 @@ def _jax_g(x, W, b): # and adjoint implicit function theorem rewritten grad # see the [notes](https://theorashid.github.io/notes/fixed-point-iteration # and the [Deep Implicit Layers workshop](https://implicit-layers-tutorial.org/implicit_functions/) + +# %% +# import jax +# import numpy as np + +# def grad_test_fixed_point_forward(): +# def jax_g(x, W, b): +# return jax.numpy.tanh(jax.numpy.dot(W, x) + b) + +# ndim = 10 +# W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jax.numpy.sqrt(ndim) +# b = jax.random.normal(jax.random.PRNGKey(1), (ndim,)) + +# W, b = np.asarray(W), np.asarray(b) # params + +# # gradient of the sum of the outputs with 
respect to the parameter matrix +# jax_grad = jax.grad( +# lambda W: jax_fixed_point_solver( +# jax_fwd_solver, +# jax_g, +# (W, b), # wrt W +# x0=jax.numpy.zeros_like(b), +# ).sum() +# )(W) +# print(jax_grad[0]) + +# grad_test_fixed_point_forward() + +# # params -> W +# # z -> x +# # x -> b +# # f = lambda W, b, x: jnp.tanh(jnp.dot(W, x) + b) +# # x_star = solver(lambda x: f(params, b, x), x_init=jnp.zeros_like(b)) +# # x_star = fixed_point_layer(fwd_solver, f, W, b) +# # g = jax.grad(lambda W: fixed_point_layer(fwd_solver, f, W, b).sum())(W) +# %% +# def implicit_gradients_vjp(solver, f, res, x_soln): +# params, x, x_star = res +# # find adjoint u^T via solver +# # u^T = w^T + u^T \delta_{x_star} f(x_star, params) +# _, vjp_x = jax.vjp(lambda : f(x, *params), x_star) # diff wrt x +# _, vjp_par = jax.vjp(lambda params: f(x, *params), *params) # diff wrt params +# u = solver(lambda u: vjp_x(u)[0] + x_soln, x0=jax.numpy.zeros_like(x_soln)) + +# # then compute vjp u^T \delta_{params} f(x_star, params) +# return vjp_par(u)
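
A note on the remaining gradient TODO: the adjoint / implicit-function-theorem backward pass can be prototyped in JAX with a custom VJP, following the Deep Implicit Layers tutorial linked from the tests. The sketch below is illustrative only: `fixed_point_layer` and its `_fwd`/`_bwd` helpers are hypothetical names, not pytensor API or anything proposed by these patches; it assumes the forward iteration converges, and it returns a zero cotangent for the initial guess. (In the commented `implicit_gradients_vjp` sketch above, `lambda : f(x, *params)` presumably should read `lambda x: f(x, *params)`.) The eventual pytensor version would presumably express the same adjoint solve as a rewrite of the scan gradient, as the TODO in the tests suggests.

# Illustrative JAX sketch of the implicit-function-theorem gradient (not part of the patches).
from functools import partial

import jax
import jax.numpy as jnp


def jax_fwd_solver(f, x0, tol=1e-5):
    # plain forward iteration x <- f(x), as in the tests above
    x_prev, x = x0, f(x0)
    while jnp.linalg.norm(x_prev - x) > tol:
        x_prev, x = x, f(x)
    return x


@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x0):
    # forward pass: solve x = f(x, *params) starting from x0
    return solver(lambda x: f(x, *params), x0)


def fixed_point_layer_fwd(solver, f, params, x0):
    x_star = fixed_point_layer(solver, f, params, x0)
    return x_star, (params, x_star, x0)


def fixed_point_layer_bwd(solver, f, res, x_star_bar):
    params, x_star, x0 = res
    # vjps of f at the fixed point, with respect to params and with respect to x
    _, vjp_params = jax.vjp(lambda p: f(x_star, *p), params)
    _, vjp_x = jax.vjp(lambda x: f(x, *params), x_star)
    # the adjoint u solves u = (df/dx)^T u + x_star_bar, itself a fixed point problem
    u = solver(lambda u: vjp_x(u)[0] + x_star_bar, jnp.zeros_like(x_star))
    # pull the adjoint back onto params; no gradient flows to the initial guess
    return vjp_params(u)[0], jnp.zeros_like(x0)


fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)


def jax_g(x, W, b):
    return jnp.tanh(jnp.dot(W, x) + b)


ndim = 10
W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jnp.sqrt(ndim)
b = jax.random.normal(jax.random.PRNGKey(1), (ndim,))

# gradient of the summed fixed point with respect to W, via the implicit vjp
grad_W = jax.grad(
    lambda W: fixed_point_layer(jax_fwd_solver, jax_g, (W, b), jnp.zeros_like(b)).sum()
)(W)
print(grad_W[0])

Differentiating `fixed_point_layer(...).sum()` with `jax.grad` gives a reference gradient that the naive `pt.grad` through each scan step could be compared against once the pytensor-side gradient work lands.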