From 3d006ff67269eeccbdf49dc04002ac433cc550ec Mon Sep 17 00:00:00 2001
From: theorashid <theoaorashid@gmail.com>
Date: Mon, 19 Aug 2024 18:54:08 +0100
Subject: [PATCH 1/7] feat: Add fixed point solver for optimization, with tests
 comparing to jax

Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
Co-authored-by: aseyboldt <aseyboldt@users.noreply.github.com>
Co-authored-by: Rob Zinkov <zaxtax@users.noreply.github.com>
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
---
 pytensor/optimise/__init__.py      |   0
 pytensor/optimise/fixed_point.py   | 134 +++++++++++++++++++++++++++++
 tests/optimise/__init__.py         |   0
 tests/optimise/test_fixed_point.py |  84 ++++++++++++++++++
 4 files changed, 218 insertions(+)
 create mode 100644 pytensor/optimise/__init__.py
 create mode 100644 pytensor/optimise/fixed_point.py
 create mode 100644 tests/optimise/__init__.py
 create mode 100644 tests/optimise/test_fixed_point.py

diff --git a/pytensor/optimise/__init__.py b/pytensor/optimise/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py
new file mode 100644
index 0000000000..e9732e1be6
--- /dev/null
+++ b/pytensor/optimise/fixed_point.py
@@ -0,0 +1,134 @@
+from functools import partial
+
+import pytensor
+import pytensor.tensor as pt
+from pytensor.scan.utils import until
+
+
+def _check_convergence(f_x, tol):
+    # TODO What convergence criterion? Norm of grad etc...
+    converged = pt.lt(pt.linalg.norm(f_x, ord=1), tol)
+    return converged
+
+
+def fwd_solver(x_prev, *args, func, tol):
+    x = func(x_prev, *args)
+    is_converged = _check_convergence(x - x_prev, tol)
+    return x, is_converged
+
+
+def newton_solver(x_prev, *args, func, tol):
+    f_root = func(x_prev, *args) - x_prev
+    jac = pt.jacobian(f_root, x_prev)
+
+    # TODO It would be nice to return the factored matrix for the pullback
+    # TODO Handle errors of the factorization
+    # 1D: x - f(x) / f'(x)
+    # general: x - J^-1 f(x)
+
+    # grad = pt.linalg.solve(jac, f_root, assume_a="sym")
+    grad = pt.linalg.solve(jac, f_root)
+    x = x_prev - grad
+
+    is_converged = _check_convergence(x - x_prev, tol)
+
+    return x, is_converged
+
+
+def fixed_point_solver(
+    f: callable,
+    solver: callable,
+    x0: pt.TensorVariable,
+    *args: pt.Variable,
+    max_iter: int = 1000,
+    tol: float = 1e-5,
+):
+    args = [pt.as_tensor(arg) for arg in args]
+    print(len(args))
+
+    def _scan_step(x, *args, func, solver, tol):
+        print(x.type)
+        x, is_converged = solver(x, *args, func=func, tol=tol)
+        print(x.type)
+        return x, until(is_converged)
+
+    partial_step = partial(
+        _scan_step,
+        func=f,
+        solver=solver,
+        tol=tol,
+    )
+
+    x_sequence, updates = pytensor.scan(
+        partial_step,
+        outputs_info=[x0],
+        non_sequences=list(args),
+        n_steps=max_iter,
+        strict=True,
+    )
+
+    assert not updates
+
+    x = x_sequence[-1]
+    return x, x_sequence.shape[0]
+
+
+# %%
+# x_star, n_steps = fixed_point_solver(
+#     g,
+#     fwd_solver,
+#     pt.zeros_like(b),
+#     W, b,
+# )
+# print(x_star.eval(), n_steps.eval())
+
+# %%
+# x_star, n_steps = fixed_point_solver(
+#     g,
+#     newton_solver,
+#     pt.zeros_like(b),
+#     W, b,
+#     max_n_steps=10,
+# )
+# print(x_star.eval(), n_steps.eval())
+
+
+# def _newton_solver(x_prev, *args, func, tol):
+#     f_root = lambda x: func(x) - x
+#     g = lambda x: x - pt.linalg.solve(pt.jacobian(f_root)(x), f_root(x))
+#     return fwd_solver(g, x)
+
+# # %%
+# def jax_newton_solver(f, z_init):
+#     f_root = lambda z: f(z) - z
+#     grad = jax.numpy.linalg.solve(jax.jacobian(f_root)(z_init), f_root(z_init))
+#     # print(np.linalg.solve(grad, f_root_z))
+#     # print(sp.linalg.solve(grad, f_root_z))
+#     # print(jax.numpy.linalg.solve(grad, f_root_z))
+#     # print(pt.linalg.solve(grad, f_root_z).eval())
+#     g = lambda z: z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z))
+#     return jax_fwd_solver(g, z_init)
+#     # return grad
+
+# def jax_fwd_solver(f, z_init):
+#     z_prev, z = z_init, f(z_init)
+#     i = 1
+#     while jax.numpy.linalg.norm(z_prev - z) > 1e-5:
+#         z_prev, z = z, f(z)
+#         i += 1
+#     print(i)
+#     return z
+
+# def _jax_g(x, W, b):
+#     return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
+
+
+# jax_g = partial(_jax_g, W=W, b=b)
+
+# print(jax_newton_solver(jax_g, jax.numpy.zeros_like(b)))
+
+# Array([-0.02879991, -0.8708013 , -1.4001148 , -0.1013868 , -0.641474  ,
+#        -0.7552165 ,  0.62554246,  0.9438805 , -0.05192749,  1.430574  ],      dtype=float32)
+
+
+# %%
diff --git a/tests/optimise/__init__.py b/tests/optimise/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/optimise/test_fixed_point.py b/tests/optimise/test_fixed_point.py
new file mode 100644
index 0000000000..90860ae15b
--- /dev/null
+++ b/tests/optimise/test_fixed_point.py
@@ -0,0 +1,84 @@
+import functools as ft
+
+import jax
+import numpy as np
+from numpy.testing import assert_array_almost_equal
+
+import pytensor.tensor as pt
+from pytensor.optimise.fixed_point import fixed_point_solver, fwd_solver, newton_solver
+
+
+jax.config.update("jax_enable_x64", True)
+
+
+def jax_newton_solver(f, z_init):
+    def f_root(z):
+        return f(z) - z
+
+    def g(z):
+        return z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z))
+
+    return jax_fwd_solver(g, z_init)
+
+
+def jax_fwd_solver(f, z_init, tol=1e-5):
+    z_prev, z = z_init, f(z_init)
+    while jax.numpy.linalg.norm(z_prev - z) > tol:
+        z_prev, z = z, f(z)
+    return z
+
+
+def test_fixed_point_forward():
+    def g(x, W, b):
+        return pt.tanh(pt.dot(W, x) + b)
+
+    def _jax_g(x, W, b):
+        return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
+
+    ndim = 10
+    W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jax.numpy.sqrt(ndim)
+    b = jax.random.normal(jax.random.PRNGKey(1), (ndim,))
+
+    W, b = np.asarray(W), np.asarray(b)
+
+    jax_g = ft.partial(_jax_g, W=W, b=b)
+
+    jax_solution = jax_fwd_solver(jax_g, jax.numpy.zeros_like(b))
+    pytensor_solution, _ = fixed_point_solver(
+        g,
+        fwd_solver,
+        pt.zeros_like(b),
+        W,
+        b,
+    )
+    assert_array_almost_equal(jax_solution, pytensor_solution.eval(), decimal=5)
+
+
+def test_fixed_point_newton():
+    def g(x, W, b):
+        return pt.tanh(pt.dot(W, x) + b)
+
+    def _jax_g(x, W, b):
+        return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
+
+    ndim = 10
+    W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jax.numpy.sqrt(ndim)
+    b = jax.random.normal(jax.random.PRNGKey(1), (ndim,))
+
+    W, b = np.asarray(W), np.asarray(b)
+
+    jax_g = ft.partial(_jax_g, W=W, b=b)
+
+    jax_solution = jax_newton_solver(jax_g, jax.numpy.zeros_like(b))
+    pytensor_solution, _ = fixed_point_solver(
+        g,
+        newton_solver,
+        pt.zeros_like(b),
+        W,
+        b,
+    )
+    assert_array_almost_equal(jax_solution, pytensor_solution.eval(), decimal=5)
+
+
+# TODO: test the grad is the same as naive grad from propagating through each step of the solver (`pt.grad`)
+# and adjoint implicit function theorem rewritten grad

From ff5c32f46ce3ad4fd1a463d7965de48f395eeb91 Mon Sep 17 00:00:00 2001
From: theorashid <theoaorashid@gmail.com>
Date: Mon, 19 Aug 2024 18:58:23 +0100
Subject: [PATCH 2/7] Further information on where test is from

---
 tests/optimise/test_fixed_point.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/optimise/test_fixed_point.py b/tests/optimise/test_fixed_point.py
index 90860ae15b..fc1bf3af34 100644
--- a/tests/optimise/test_fixed_point.py
+++ b/tests/optimise/test_fixed_point.py
@@ -29,6 +29,8 @@ def jax_fwd_solver(f, z_init, tol=1e-5):
 
 
 def test_fixed_point_forward():
+    """Test taken from the [Deep Implicit Layers workshop](https://implicit-layers-tutorial.org/implicit_functions/)."""
+
     def g(x, W, b):
         return pt.tanh(pt.dot(W, x) + b)
 
@@ -82,3 +84,5 @@ def _jax_g(x, W, b):
 
 # TODO: test the grad is the same as naive grad from propagating through each step of the solver (`pt.grad`)
 # and adjoint implicit function theorem rewritten grad
+# see the [notes](https://theorashid.github.io/notes/fixed-point-iteration)
+# and the [Deep Implicit Layers workshop](https://implicit-layers-tutorial.org/implicit_functions/)

From 01858c41a78ca8a49b609eb1867f6896172bb0a2 Mon Sep 17 00:00:00 2001
From: theorashid <theoaorashid@gmail.com>
Date: Mon, 19 Aug 2024 18:59:15 +0100
Subject: [PATCH 3/7] remove commented test case

---
 pytensor/optimise/fixed_point.py | 63 +-------------------------------
 1 file changed, 1 insertion(+), 62 deletions(-)

diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py
index e9732e1be6..5b00b2f2c3 100644
--- a/pytensor/optimise/fixed_point.py
+++ b/pytensor/optimise/fixed_point.py
@@ -26,7 +26,7 @@ def newton_solver(x_prev, *args, func, tol):
     # 1D: x - f(x) / f'(x)
     # general: x - J^-1 f(x)
 
-    # grad = pt.linalg.solve(jac, f_root, assume_a="sym")
+    # TODO: consider `grad = pt.linalg.solve(jac, f_root, assume_a="sym")`
     grad = pt.linalg.solve(jac, f_root)
     x = x_prev - grad
 
@@ -71,64 +71,3 @@ def _scan_step(x, *args, func, solver, tol):
 
     x = x_sequence[-1]
     return x, x_sequence.shape[0]
-
-
-# %%
-# x_star, n_steps = fixed_point_solver(
-#     g,
-#     fwd_solver,
-#     pt.zeros_like(b),
-#     W, b,
-# )
-# print(x_star.eval(), n_steps.eval())
-
-# %%
-# x_star, n_steps = fixed_point_solver(
-#     g,
-#     newton_solver,
-#     pt.zeros_like(b),
-#     W, b,
-#     max_n_steps=10,
-# )
-# print(x_star.eval(), n_steps.eval())
-
-
-# def _newton_solver(x_prev, *args, func, tol):
-#     f_root = lambda x: func(x) - x
-#     g = lambda x: x - pt.linalg.solve(pt.jacobian(f_root)(x), f_root(x))
-#     return fwd_solver(g, x)
-
-# # %%
-# def jax_newton_solver(f, z_init):
-#     f_root = lambda z: f(z) - z
-#     grad = jax.numpy.linalg.solve(jax.jacobian(f_root)(z_init), f_root(z_init))
-#     # print(np.linalg.solve(grad, f_root_z))
-#     # print(sp.linalg.solve(grad, f_root_z))
-#     # print(jax.numpy.linalg.solve(grad, f_root_z))
-#     # print(pt.linalg.solve(grad, f_root_z).eval())
-#     g = lambda z: z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z))
-#     return jax_fwd_solver(g, z_init)
-#     # return grad
-
-# def jax_fwd_solver(f, z_init):
-#     z_prev, z = z_init, f(z_init)
-#     i = 1
-#     while jax.numpy.linalg.norm(z_prev - z) > 1e-5:
-#         z_prev, z = z, f(z)
-#         i += 1
-#     print(i)
-#     return z
-
-# def _jax_g(x, W, b):
-#     return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
-
-
-# jax_g = partial(_jax_g, W=W, b=b)
-
-# print(jax_newton_solver(jax_g, jax.numpy.zeros_like(b)))
-
-# Array([-0.02879991, -0.8708013 , -1.4001148 , -0.1013868 , -0.641474  ,
-#        -0.7552165 ,  0.62554246,  0.9438805 , -0.05192749,  1.430574  ],      dtype=float32)
-
-
-# %%

From c2f2fce112cd1eb373317b7639b299cc5d09c86f Mon Sep 17 00:00:00 2001
From: theorashid <theoaorashid@gmail.com>
Date: Mon, 19 Aug 2024 19:04:55 +0100
Subject: [PATCH 4/7] remove debug print statements from scan

---
 pytensor/optimise/fixed_point.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py
index 5b00b2f2c3..2befb8c502 100644
--- a/pytensor/optimise/fixed_point.py
+++ b/pytensor/optimise/fixed_point.py
@@ -44,12 +44,9 @@ def fixed_point_solver(
     tol: float = 1e-5,
 ):
     args = [pt.as_tensor(arg) for arg in args]
-    print(len(args))
 
     def _scan_step(x, *args, func, solver, tol):
-        print(x.type)
         x, is_converged = solver(x, *args, func=func, tol=tol)
-        print(x.type)
         return x, until(is_converged)
 
     partial_step = partial(
@@ -59,7 +56,7 @@ def _scan_step(x, *args, func, solver, tol):
         tol=tol,
     )
 
-    x_sequence, updates = pytensor.scan(
+    outputs, updates = pytensor.scan(
         partial_step,
         outputs_info=[x0],
         non_sequences=list(args),
@@ -67,7 +64,7 @@ def _scan_step(x, *args, func, solver, tol):
         strict=True,
     )
 
+    x_trace = outputs
     assert not updates
 
-    x = x_sequence[-1]
-    return x, x_sequence.shape[0]
+    return x_trace[-1], x_trace.shape[0]

From 0121ec3e315608f240d4c0ce13246a0265e96af4 Mon Sep 17 00:00:00 2001
From: theorashid <theoaorashid@gmail.com>
Date: Mon, 19 Aug 2024 19:11:28 +0100
Subject: [PATCH 5/7] use trace to work out n_steps (see
 https://github.com/aseyboldt/pytensor/blob/newton/experiment-newton.ipynb)

---
 pytensor/optimise/fixed_point.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py
index 2befb8c502..736fe702d1 100644
--- a/pytensor/optimise/fixed_point.py
+++ b/pytensor/optimise/fixed_point.py
@@ -45,9 +45,9 @@ def fixed_point_solver(
 ):
     args = [pt.as_tensor(arg) for arg in args]
 
-    def _scan_step(x, *args, func, solver, tol):
+    def _scan_step(x, n_steps, *args, func, solver, tol):
         x, is_converged = solver(x, *args, func=func, tol=tol)
-        return x, until(is_converged)
+        return (x, n_steps + 1), until(is_converged)
 
     partial_step = partial(
         _scan_step,
@@ -58,13 +58,13 @@ def _scan_step(x, *args, func, solver, tol):
 
     outputs, updates = pytensor.scan(
         partial_step,
-        outputs_info=[x0],
+        outputs_info=[x0, pt.constant(0, dtype="int64")],
         non_sequences=list(args),
         n_steps=max_iter,
         strict=True,
     )
 
-    x_trace = outputs
+    x_trace, n_steps_trace = outputs
     assert not updates
 
-    return x_trace[-1], x_trace.shape[0]
+    return x_trace[-1], n_steps_trace[-1]

From a49ade3ec4b27672066cd5c03e9527fccc775726 Mon Sep 17 00:00:00 2001
From: theorashid <theoaorashid@gmail.com>
Date: Mon, 19 Aug 2024 19:14:49 +0100
Subject: [PATCH 6/7] add TODO refactor newton solver to use `fwd_solver`

---
 pytensor/optimise/fixed_point.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py
index 736fe702d1..d23a4d60d2 100644
--- a/pytensor/optimise/fixed_point.py
+++ b/pytensor/optimise/fixed_point.py
@@ -30,6 +30,7 @@ def newton_solver(x_prev, *args, func, tol):
     grad = pt.linalg.solve(jac, f_root)
     x = x_prev - grad
 
+    # TODO: consider if this can all be done as a single call to `fwd_solver` as in the jax test case
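+    # A possible shape for that refactor (hedged, untested sketch; `_newton_step`
+    # is a hypothetical helper): express the Newton update as a fixed-point map
+    # and delegate the step and convergence check to `fwd_solver`, e.g.
+    #
+    #     def _newton_step(x, *args):
+    #         f_root = func(x, *args) - x
+    #         return x - pt.linalg.solve(pt.jacobian(f_root, x), f_root)
+    #
+    #     return fwd_solver(x_prev, *args, func=_newton_step, tol=tol)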
     is_converged = _check_convergence(x - x_prev, tol)
 
     return x, is_converged

From 57aaf8c1282e1f33d6256b123fb23085f2c1300a Mon Sep 17 00:00:00 2001
From: theorashid <theoaorashid@gmail.com>
Date: Wed, 11 Sep 2024 13:29:52 +0100
Subject: [PATCH 7/7] add notes so hackers can see

---
 pytensor/optimise/fixed_point.py   |  5 +-
 tests/optimise/test_fixed_point.py | 95 ++++++++++++++++++++++++------
 2 files changed, 79 insertions(+), 21 deletions(-)

diff --git a/pytensor/optimise/fixed_point.py b/pytensor/optimise/fixed_point.py
index d23a4d60d2..370e04ad83 100644
--- a/pytensor/optimise/fixed_point.py
+++ b/pytensor/optimise/fixed_point.py
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 from functools import partial
 
 import pytensor
@@ -37,8 +38,8 @@ def newton_solver(x_prev, *args, func, tol):
 
 
 def fixed_point_solver(
-    f: callable,
-    solver: callable,
+    f: Callable,
+    solver: Callable,
     x0: pt.TensorVariable,
     *args: pt.Variable,
     max_iter: int = 1000,
diff --git a/tests/optimise/test_fixed_point.py b/tests/optimise/test_fixed_point.py
index fc1bf3af34..2e9cd428ee 100644
--- a/tests/optimise/test_fixed_point.py
+++ b/tests/optimise/test_fixed_point.py
@@ -1,5 +1,3 @@
-import functools as ft
-
 import jax
 import numpy as np
 from numpy.testing import assert_array_almost_equal
@@ -11,21 +9,26 @@
 jax.config.update("jax_enable_x64", True)
 
 
-def jax_newton_solver(f, z_init):
-    def f_root(z):
-        return f(z) - z
+def jax_newton_solver(f, x0):
+    def f_root(x):
+        return f(x) - x
+
+    def g(x):
+        return x - jax.numpy.linalg.solve(jax.jacobian(f_root)(x), f_root(x))
+
+    return jax_fwd_solver(g, x0)
 
-    def g(z):
-        return z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z))
 
-    return jax_fwd_solver(g, z_init)
+def jax_fwd_solver(f, x0, tol=1e-5):
+    x_prev, x = x0, f(x0)
+    while jax.numpy.linalg.norm(x_prev - x) > tol:
+        x_prev, x = x, f(x)
+    return x
 
 
-def jax_fwd_solver(f, z_init, tol=1e-5):
-    z_prev, z = z_init, f(z_init)
-    while jax.numpy.linalg.norm(z_prev - z) > tol:
-        z_prev, z = z, f(z)
-    return z
+def jax_fixed_point_solver(solver, f, params, x0, **solver_kwargs):
+    x_star = solver(lambda x: f(x, *params), x0=x0, **solver_kwargs)
+    return x_star
 
 
 def test_fixed_point_forward():
@@ -34,7 +37,7 @@ def test_fixed_point_forward():
     def g(x, W, b):
         return pt.tanh(pt.dot(W, x) + b)
 
-    def _jax_g(x, W, b):
+    def jax_g(x, W, b):
         return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
 
     ndim = 10
@@ -43,9 +46,13 @@ def _jax_g(x, W, b):
 
     W, b = np.asarray(W), np.asarray(b)
 
-    jax_g = ft.partial(_jax_g, W=W, b=b)
+    jax_solution = jax_fixed_point_solver(
+        jax_fwd_solver,
+        jax_g,
+        (W, b),
+        x0=jax.numpy.zeros_like(b),
+    )
 
-    jax_solution = jax_fwd_solver(jax_g, jax.numpy.zeros_like(b))
     pytensor_solution, _ = fixed_point_solver(
         g,
         fwd_solver,
@@ -60,7 +67,7 @@ def test_fixed_point_newton():
     def g(x, W, b):
         return pt.tanh(pt.dot(W, x) + b)
 
-    def _jax_g(x, W, b):
+    def jax_g(x, W, b):
         return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
 
     ndim = 10
@@ -69,9 +76,13 @@ def _jax_g(x, W, b):
 
     W, b = np.asarray(W), np.asarray(b)
 
-    jax_g = ft.partial(_jax_g, W=W, b=b)
+    jax_solution = jax_fixed_point_solver(
+        jax_newton_solver,
+        jax_g,
+        (W, b),
+        x0=jax.numpy.zeros_like(b),
+    )
 
-    jax_solution = jax_newton_solver(jax_g, jax.numpy.zeros_like(b))
     pytensor_solution, _ = fixed_point_solver(
         g,
         newton_solver,
@@ -86,3 +97,49 @@ def _jax_g(x, W, b):
 # and adjoint implicit function theorem rewritten grad
 # see the [notes](https://theorashid.github.io/notes/fixed-point-iteration)
 # and the [Deep Implicit Layers workshop](https://implicit-layers-tutorial.org/implicit_functions/)
+
+# %%
+# import jax
+# import numpy as np
+
+# def grad_test_fixed_point_forward():
+#     def jax_g(x, W, b):
+#         return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
+
+#     ndim = 10
+#     W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jax.numpy.sqrt(ndim)
+#     b = jax.random.normal(jax.random.PRNGKey(1), (ndim,))
+
+#     W, b = np.asarray(W), np.asarray(b)  # params
+
+#     # gradient of the sum of the outputs with respect to the parameter matrix
+#     jax_grad = jax.grad(
+#         lambda W: jax_fixed_point_solver(
+#             jax_fwd_solver,
+#             jax_g,
+#             (W, b),  # wrt W
+#             x0=jax.numpy.zeros_like(b),
+#         ).sum()
+#     )(W)
+#     print(jax_grad[0])
+
+# grad_test_fixed_point_forward()
+
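+# A corresponding pytensor-side check might look something like the following
+# (hedged, untested sketch: it assumes `pt.grad` can differentiate through the
+# scan-with-until graph built by `fixed_point_solver`, reuses `g`, `W`, `b` from
+# the tests above, and `W_pt` is a hypothetical symbolic input):
+#
+#     W_pt = pt.matrix("W")
+#     x_star, _ = fixed_point_solver(g, fwd_solver, pt.zeros_like(b), W_pt, b)
+#     pt_grad = pt.grad(x_star.sum(), W_pt)
+#     assert_array_almost_equal(jax_grad, pt_grad.eval({W_pt: W}), decimal=5)
+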
+#     # params -> W
+#     # z -> x
+#     # x -> b
+#     # f = lambda W, b, x: jnp.tanh(jnp.dot(W, x) + b)
+#     # x_star = solver(lambda x: f(params, b, x), x_init=jnp.zeros_like(b))
+#     # x_star = fixed_point_layer(fwd_solver, f, W, b)
+#     # g = jax.grad(lambda W: fixed_point_layer(fwd_solver, f, W, b).sum())(W)
+# %%
+# def implicit_gradients_vjp(solver, f, res, x_soln):
+#     params, x, x_star = res
+#     # find adjoint u^T via solver
+#     # u^T = w^T + u^T \delta_{x_star} f(x_star, params)
+#     _, vjp_x = jax.vjp(lambda x: f(x, *params), x_star)  # diff wrt x, evaluated at x_star
+#     _, vjp_par = jax.vjp(lambda params: f(x_star, *params), params)  # diff wrt params
+#     u = solver(lambda u: vjp_x(u)[0] + x_soln, x0=jax.numpy.zeros_like(x_soln))
+
+#     # then compute vjp u^T \delta_{params} f(x_star, params)
+#     return vjp_par(u)
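+
+# For reference, the adjoint above can be wired into jax via `jax.custom_vjp`.
+# This is a hedged, untested sketch following the Deep Implicit Layers tutorial
+# pattern (the `fixed_point_layer*` names are placeholders, and `f(x, *params)`
+# matches the convention used in this file):
+#
+#     from functools import partial
+#
+#     @partial(jax.custom_vjp, nondiff_argnums=(0, 1))
+#     def fixed_point_layer(solver, f, params, x0):
+#         return solver(lambda x: f(x, *params), x0=x0)
+#
+#     def fixed_point_layer_fwd(solver, f, params, x0):
+#         x_star = fixed_point_layer(solver, f, params, x0)
+#         return x_star, (params, x_star)
+#
+#     def fixed_point_layer_bwd(solver, f, res, x_star_bar):
+#         params, x_star = res
+#         _, vjp_par = jax.vjp(lambda p: f(x_star, *p), params)
+#         _, vjp_x = jax.vjp(lambda x: f(x, *params), x_star)
+#         # adjoint solve: u = x_star_bar + u^T d_x f(x_star, params), via the same solver
+#         u = solver(lambda u: vjp_x(u)[0] + x_star_bar, x0=jax.numpy.zeros_like(x_star))
+#         # cotangents for (params, x0); x0 does not affect the fixed point
+#         return vjp_par(u) + (jax.numpy.zeros_like(x_star),)
+#
+#     fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)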