diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index a740a75..04e0bdf 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -145,7 +145,9 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array: )(jnp.zeros((val_subset._get_tangent_dim(),))) # Compute Jacobian for each factor. - stacked_jac = jax.vmap(compute_jac_with_perturb)(factor) + stacked_jac = jax.lax.map( + compute_jac_with_perturb, factor, batch_size=factor.jac_batch_size + ) (num_factor,) = factor._get_batch_axes() assert stacked_jac.shape == ( num_factor, @@ -396,11 +398,20 @@ def _sort_key(x: Any) -> str: @jdc.pytree_dataclass class Factor[*Args]: - """A single cost in our factor graph.""" + """A single cost in our factor graph. Costs with the same pytree structure + will automatically be paralellized.""" compute_residual: jdc.Static[Callable[[VarValues, *Args], jax.Array]] args: tuple[*Args] jac_mode: jdc.Static[Literal["auto", "forward", "reverse"]] = "auto" + """Depending on the function being differentiated, it may be faster to use + forward-mode or reverse-mode autodiff.""" + jac_batch_size: jdc.Static[int | None] = None + """Batch size for computing Jacobians that can be parallelized. Can be set + to make tradeoffs between runtime and memory usage. + + If None, we compute all Jacobians in parallel. If 1, we compute Jacobians + one at a time.""" @staticmethod @deprecated("Use Factor() directly instead of Factor.make()") diff --git a/src/jaxls/_variables.py b/src/jaxls/_variables.py index bdd46f3..01b1319 100644 --- a/src/jaxls/_variables.py +++ b/src/jaxls/_variables.py @@ -3,7 +3,7 @@ from collections.abc import Iterable from dataclasses import dataclass from functools import total_ordering -from typing import Any, Callable, ClassVar, cast, overload +from typing import Any, Callable, ClassVar, Self, cast, overload import jax import jax_dataclasses as jdc @@ -83,6 +83,11 @@ def with_value(self, value: T) -> VarWithValue[T]: for `VarValues.make()`.""" return VarWithValue(self, value) + def __getitem__(self, index_or_slice: int | slice) -> Self: + """Shorthand for slicing the variable ID.""" + assert not isinstance(self.id, int) + return self.__class__(self.id[index_or_slice]) + @overload def __init_subclass__[T_]( cls, @@ -174,7 +179,10 @@ class VarValues: """Variable ID for each value, sorted in ascending order.""" def get_value[T](self, var: Var[T]) -> T: - """Get the value of a specific variable.""" + """Get the value of a specific variable or variables.""" + if not isinstance(var.id, int) and var.id.ndim > 0: + return jax.vmap(self.get_value)(var) + assert getattr(var.id, "shape", None) == () or isinstance(var.id, int) var_type = type(var) index = jnp.searchsorted(self.ids_from_type[var_type], var.id)