Skip to content

Commit b799547

Browse files
authored
Adding NUTS sampler from blackjax to sampling_jax (#5477)
* Adding NUTS sampler from blackjax to sampling_jax
* Lint fixes
* Refactor get_jaxified_logp
* Install blackjax in workflows
* Fix url
* Simplify function
* Add documentation
* Add library versions
* Move to more appropriate section
* Fix docstrings
* Add jax to dev environment
* Jax doesn't work on windows, remove it
* Exclude jax
* Fix merge
* Remove progress bar functionality for now
* Pre-commit fix
1 parent 65dcb49 commit b799547

11 files changed

+294
-15
lines changed

Diff for: .github/workflows/jaxtests.yml

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ jobs:
7171
run: |
7272
conda activate pymc-test-py39
7373
pip install "numpyro>=0.8.0"
74+
pip install git+https://github.com/blackjax-devs/blackjax.git@main
7475
- name: Run tests
7576
run: |
7677
python -m pytest -vv --cov=pymc --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET

Diff for: conda-envs/environment-dev-py37.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- fastprogress>=0.2.0
1414
- h5py>=2.7
1515
- ipython>=7.16
16+
- jax
1617
- myst-nb
1718
- numpy>=1.15.0
1819
- numpydoc<1.2

Diff for: conda-envs/environment-dev-py38.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- fastprogress>=0.2.0
1414
- h5py>=2.7
1515
- ipython>=7.16
16+
- jax
1617
- myst-nb
1718
- numpy>=1.15.0
1819
- numpydoc<1.2

Diff for: conda-envs/environment-dev-py39.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- fastprogress>=0.2.0
1414
- h5py>=2.7
1515
- ipython>=7.16
16+
- jax
1617
- myst-nb
1718
- numpy>=1.15.0
1819
- numpydoc<1.2

Diff for: conda-envs/environment-test-py37.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- fastprogress>=0.2.0
1414
- h5py>=2.7
1515
- ipython>=7.16
16+
- jax
1617
- libblas=*=*mkl
1718
- mkl-service
1819
- numpy>=1.15.0

Diff for: conda-envs/environment-test-py38.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- fastprogress>=0.2.0
1414
- h5py>=2.7
1515
- ipython>=7.16
16+
- jax
1617
- libblas=*=*mkl
1718
- mkl-service
1819
- numpy>=1.15.0

Diff for: conda-envs/environment-test-py39.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- fastprogress>=0.2.0
1414
- h5py>=2.7
1515
- ipython>=7.16
16+
- jax
1617
- libblas=*=*mkl
1718
- mkl-service
1819
- numpy>=1.15.0

Diff for: docs/source/api/samplers.rst

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ This submodule contains functions for MCMC and forward sampling.
1313
sample_prior_predictive
1414
sample_posterior_predictive
1515
sample_posterior_predictive_w
16+
sampling_jax.sample_blackjax_nuts
17+
sampling_jax.sample_numpyro_nuts
1618
iter_sample
1719
init_nuts
1820
draw

Diff for: pymc/sampling_jax.py

+255-8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import warnings
55

6+
from functools import partial
67
from typing import Callable, Dict, List, Optional, Sequence, Union
78

89
from pymc.initial_point import StartDict
@@ -26,6 +27,7 @@
2627
from aesara.link.jax.dispatch import jax_funcify
2728
from aesara.raise_op import Assert
2829
from aesara.tensor import TensorVariable
30+
from arviz.data.base import make_attrs
2931

3032
from pymc import Model, modelcontext
3133
from pymc.backends.arviz import find_observations
@@ -97,14 +99,14 @@ def get_jaxified_graph(
9799
return jax_funcify(fgraph)
98100

99101

def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
    """Build a JAX-callable joint log-probability function for ``model``.

    Parameters
    ----------
    model : Model
        Model whose ``logpt()`` graph is converted with ``get_jaxified_graph``,
        using ``model.value_vars`` as the graph inputs.
    negative_logp : bool, default True
        NOTE: the flag reads inverted relative to its name. When truthy, the
        returned function evaluates ``model.logpt()`` unchanged; when falsy, it
        evaluates ``-model.logpt()``. The behaviour is kept as-is because
        existing callers pass ``negative_logp=False`` to obtain the negated
        value.

    Returns
    -------
    Callable
        Function of a single argument ``x`` (a sequence that is unpacked into
        the jaxified graph, one entry per value variable) returning the first
        output of that graph.
    """
    # Pick the graph output once; the sign is baked into the compiled graph
    # rather than applied inside the wrapper.
    target = model.logpt() if negative_logp else -model.logpt()
    jaxified = get_jaxified_graph(inputs=model.value_vars, outputs=[target])

    def logp_fn_wrap(x):
        # The jaxified graph returns a sequence of outputs; only one was requested.
        outputs = jaxified(*x)
        return outputs[0]

    return logp_fn_wrap
110112

@@ -177,6 +179,202 @@ def _get_batched_jittered_initial_points(
177179
return initial_points
178180

179181

@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def _blackjax_inference_loop(
    seed,
    init_position,
    logprob_fn,
    draws,
    tune,
    target_accept,
    algorithm=None,
):
    """Run window adaptation and then ``draws`` sampling steps for one chain.

    The function is wrapped in ``jax.jit``; every argument except ``seed`` and
    ``init_position`` is declared static via ``static_argnums``.

    Parameters
    ----------
    seed
        JAX PRNG key for this chain. It is split once into an adaptation key
        and a sampling key so the two phases never consume the same key.
    init_position
        Initial position handed to the adaptation run.
    logprob_fn : Callable
        Log-probability function passed to ``blackjax.window_adaptation``.
    draws : int
        Number of post-adaptation steps to take (number of keys scanned over).
    tune : int
        Number of adaptation steps (``num_steps`` of the window adaptation).
    target_accept : float
        Target acceptance rate passed to the window adaptation.
    algorithm : optional
        Sampling algorithm to adapt; defaults to ``blackjax.nuts``.

    Returns
    -------
    tuple
        ``(states, infos)`` as stacked by ``jax.lax.scan`` over the sampling
        steps.
    """
    import blackjax

    if algorithm is None:
        algorithm = blackjax.nuts

    # JAX PRNG keys must not be reused: previously ``seed`` was given to the
    # adaptation run and then split again for the sampling steps, so both
    # phases derived their randomness from the same key.
    adapt_key, sample_key = jax.random.split(seed)

    adapt = blackjax.window_adaptation(
        algorithm=algorithm,
        logprob_fn=logprob_fn,
        num_steps=tune,
        target_acceptance_rate=target_accept,
    )
    last_state, kernel, _ = adapt.run(adapt_key, init_position)

    def one_step(state, rng_key):
        # Carry the new state forward and also emit it (with its info) so that
        # ``lax.scan`` stacks every draw.
        state, info = kernel(rng_key, state)
        return state, (state, info)

    step_keys = jax.random.split(sample_key, draws)
    _, (states, infos) = jax.lax.scan(one_step, last_state, step_keys)

    return states, infos
216+
217+
def sample_blackjax_nuts(
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    target_accept: float = 0.8,
    random_seed: int = 10,
    initvals=None,
    model=None,
    var_names=None,
    keep_untransformed: bool = False,
    chain_method: str = "parallel",
    idata_kwargs: Optional[Dict] = None,
):
    """
    Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

    Parameters
    ----------
    draws : int, default 1000
        The number of samples to draw per chain. Tuning samples are not included
        in the returned posterior.
    tune : int, default 1000
        Number of iterations to tune. Samplers adjust the step sizes, scalings or
        similar during tuning. Tuning samples will be drawn in addition to the number
        specified in the ``draws`` argument.
    chains : int, default 4
        The number of chains to sample.
    target_accept : float in [0, 1], default 0.8
        The step size is tuned such that we approximate this acceptance rate. Higher
        values like 0.9 or 0.95 often work better for problematic posteriors.
    random_seed : int, default 10
        Random seed. It is used both for the jittered initial points and to build
        the JAX PRNG key that is split across chains.
    initvals : optional
        Initial values, forwarded unchanged to
        ``_get_batched_jittered_initial_points``.
    model : Model, optional
        Model to sample from. The model needs to have free random variables. When
        inside a ``with`` model context, it defaults to that model, otherwise the
        model must be passed explicitly.
    var_names : iterable, optional
        Variables for which to compute the posterior samples. Defaults to
        ``model.unobserved_value_vars``.
    keep_untransformed : bool, default False
        Forwarded as ``include_transformed`` to ``get_default_varnames`` when
        selecting the variables to return.
    chain_method : str, default "parallel"
        How chains are mapped: "parallel" uses ``jax.pmap`` and "vectorized" uses
        ``jax.vmap``.
    idata_kwargs : dict, optional
        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean
        under the ``log_likelihood`` key; when false, the pointwise log likelihood
        is not computed. The passed dict is not modified.

    Returns
    -------
    InferenceData
        ArviZ ``InferenceData`` object that contains the posterior samples, the
        observed data and the pointwise log likelihood values (unless skipped with
        ``idata_kwargs``). Sampler diagnostics are currently discarded, so no
        ``sample_stats`` group is passed to ArviZ.

    Raises
    ------
    ValueError
        If ``chain_method`` is neither "parallel" nor "vectorized".
    """
    import blackjax

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in var_dims if dim is not None]
            for var_name, var_dims in model.RV_dims.items()
        }
    else:
        dims = {}

    # Validate the mapping strategy before any expensive work (initial points,
    # graph jaxification) so that a bad argument fails fast.
    # Adapted from numpyro
    if chain_method == "parallel":
        map_fn = jax.pmap
    elif chain_method == "vectorized":
        map_fn = jax.vmap
    else:
        raise ValueError(
            "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
        )

    tic1 = datetime.now()
    print("Compiling...", file=sys.stdout)

    init_params = _get_batched_jittered_initial_points(
        model=model,
        chains=chains,
        initvals=initvals,
        random_seed=random_seed,
    )

    # Re-pack the initial points into one stacked array per value variable.
    # NOTE(review): presumably the helper returns an un-batched point when
    # ``chains == 1``, hence the extra wrapping -- confirm against the helper.
    if chains == 1:
        init_params = [np.stack(init_params)]
    init_params = [np.stack(init_state) for init_state in zip(*init_params)]

    # Default ``negative_logp=True`` yields the un-negated ``model.logpt()``.
    logprob_fn = get_jaxified_logp(model)

    seed = jax.random.PRNGKey(random_seed)
    keys = jax.random.split(seed, chains)

    get_posterior_samples = partial(
        _blackjax_inference_loop,
        logprob_fn=logprob_fn,
        tune=tune,
        draws=draws,
        target_accept=target_accept,
    )

    # NOTE: the inference loop has not been invoked yet at this point, so the
    # interval reported below covers only the setup above.
    tic2 = datetime.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    # The second return value (per-step sampler info) is discarded.
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
    raw_mcmc_samples = states.position

    tic3 = datetime.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
        # Double vmap: map over the two leading axes of the raw samples.
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    # Work on a copy so the caller's dict is not mutated by ``pop`` below.
    idata_kwargs = {} if idata_kwargs is None else idata_kwargs.copy()

    if idata_kwargs.pop("log_likelihood", True):
        log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
    else:
        log_likelihood = None

    attrs = {
        "sampling_time": (tic3 - tic2).total_seconds(),
    }

    az_trace = az.from_dict(
        posterior=mcmc_samples,
        log_likelihood=log_likelihood,
        observed_data=find_observations(model),
        coords=coords,
        dims=dims,
        attrs=make_attrs(attrs, library=blackjax),
        **idata_kwargs,
    )

    return az_trace
376+
377+
180378
def sample_numpyro_nuts(
181379
draws: int = 1000,
182380
tune: int = 1000,
@@ -192,6 +390,51 @@ def sample_numpyro_nuts(
192390
idata_kwargs: Optional[Dict] = None,
193391
nuts_kwargs: Optional[Dict] = None,
194392
):
393+
"""
394+
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
395+
396+
Parameters
397+
----------
398+
draws : int, default 1000
399+
The number of samples to draw. The number of tuned samples are discarded by default.
400+
tune : int, default 1000
401+
Number of iterations to tune. Samplers adjust the step sizes, scalings or
402+
similar during tuning. Tuning samples will be drawn in addition to the number specified in
403+
the ``draws`` argument.
404+
chains : int, default 4
405+
The number of chains to sample.
406+
target_accept : float in [0, 1].
407+
The step size is tuned such that we approximate this acceptance rate. Higher values like
408+
0.9 or 0.95 often work better for problematic posteriors.
409+
random_seed : int, default 10
410+
Random seed used by the sampling steps.
411+
model : Model, optional
412+
Model to sample from. The model needs to have free random variables. When inside a ``with`` model
413+
context, it defaults to that model, otherwise the model must be passed explicitly.
414+
var_names : iterable of str, optional
415+
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior
416+
progress_bar : bool, default True
417+
Whether or not to display a progress bar in the command line. The bar shows the percentage
418+
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
419+
time until completion ("expected time of arrival"; ETA).
420+
keep_untransformed : bool, default False
421+
Include untransformed variables in the posterior samples. Defaults to False.
422+
chain_method : str, default "parallel"
423+
Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
424+
idata_kwargs : dict, optional
425+
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
426+
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
427+
not be included in the returned object.
428+
429+
Returns
430+
-------
431+
InferenceData
432+
ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
433+
pointwise log likelihood values (unless skipped with ``idata_kwargs``).
434+
"""
435+
436+
import numpyro
437+
195438
from numpyro.infer import MCMC, NUTS
196439

197440
model = modelcontext(model)
@@ -228,7 +471,7 @@ def sample_numpyro_nuts(
228471
random_seed=random_seed,
229472
)
230473

231-
logp_fn = get_jaxified_logp(model)
474+
logp_fn = get_jaxified_logp(model, negative_logp=False)
232475

233476
if nuts_kwargs is None:
234477
nuts_kwargs = {}
@@ -298,6 +541,10 @@ def sample_numpyro_nuts(
298541
else:
299542
log_likelihood = None
300543

544+
attrs = {
545+
"sampling_time": (tic3 - tic2).total_seconds(),
546+
}
547+
301548
posterior = mcmc_samples
302549
az_trace = az.from_dict(
303550
posterior=posterior,
@@ -306,7 +553,7 @@ def sample_numpyro_nuts(
306553
sample_stats=_sample_stats_to_xarray(pmap_numpyro),
307554
coords=coords,
308555
dims=dims,
309-
attrs={"sampling_time": (tic3 - tic2).total_seconds()},
556+
attrs=make_attrs(attrs, library=numpyro),
310557
**idata_kwargs,
311558
)
312559

0 commit comments

Comments
 (0)