From 0991ee8872c0d53e451a997ce19e5672997062c3 Mon Sep 17 00:00:00 2001 From: brentyi Date: Mon, 23 Dec 2024 02:20:05 -0800 Subject: [PATCH] Fix "auto" for forward-mode vs reverse-mode jacobians --- src/jaxls/_factor_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index aa41a00..78317b9 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -134,7 +134,7 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array: "reverse": jax.jacrev, "auto": jax.jacrev if factor.residual_dim < val_subset._get_tangent_dim() - else jax.jacrev, + else jax.jacfwd, }[factor.jac_mode] return jacfunc( # The residual function, with respect to to some local delta.