Skip to content

Commit 5055262

Browse files
Add find_MAP with close JAX integration and fix bug with Laplace fit (#385)
* Add JAX-based `find_MAP` * add `better_optimize` to CI envs * Fix relative import * Remove `find_MAP` import from module-level `__init__.py` * Update docstring * Allow calling `find_MAP` inside model context without model argument * Required patched better_optimize * in-progress refactor * More refactor * Generalize code to use any pytensor backend * Reconcile the two laplace approximation functions * Use absolute import in doctest * Fix imports * Fix unrelated statespace test * - Rename argument `use_jax_gradients` -> `gradient_backend` - Rename function `laplace` -> `sample_laplace_posterior` * Fix typo introduced by rename refactor * use `mode=FAST_COMPILE` to get `unobserved_value_vars` after MAP optimization * Rename `test_jax_find_map.py` -> `test_find_map.py` * Improve docstring for `fit_laplace` * Update tests to match new signature * Update docstring
1 parent 40714de commit 5055262

File tree

8 files changed

+1178
-163
lines changed

8 files changed

+1178
-163
lines changed

conda-envs/environment-test.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ dependencies:
1313
- pymc>=5.17.0 # CI was failing to resolve
1414
- blackjax
1515
- scikit-learn
16+
- better_optimize>=0.0.10

conda-envs/windows-environment-test.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ dependencies:
1313
- pymc>=5.17.0 # CI was failing to resolve
1414
- blackjax
1515
- scikit-learn
16+
- better_optimize>=0.0.10

0 commit comments

Comments
 (0)