3
3
import sys
4
4
import warnings
5
5
6
+ from functools import partial
6
7
from typing import Callable , Dict , List , Optional , Sequence , Union
7
8
8
9
from pymc .initial_point import StartDict
26
27
from aesara .link .jax .dispatch import jax_funcify
27
28
from aesara .raise_op import Assert
28
29
from aesara .tensor import TensorVariable
30
+ from arviz .data .base import make_attrs
29
31
30
32
from pymc import Model , modelcontext
31
33
from pymc .backends .arviz import find_observations
@@ -97,14 +99,14 @@ def get_jaxified_graph(
97
99
return jax_funcify (fgraph )
98
100
99
101
100
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
    """Compile the model's joint log-probability graph into a JAX-callable.

    Parameters
    ----------
    model : Model
        Model whose ``logpt()`` graph is compiled with ``get_jaxified_graph``,
        using ``model.value_vars`` as the graph inputs.
    negative_logp : bool, default True
        NOTE: the flag reads inverted relative to its name. With the default
        ``True`` the graph returned by ``model.logpt()`` is compiled unchanged;
        with ``False`` its sign is flipped before compilation.
        ``sample_numpyro_nuts`` passes ``False`` to obtain the sign-flipped value.

    Returns
    -------
    Callable
        A function taking one sequence ``x``; its elements are unpacked as the
        positional inputs of the compiled graph and the first output is returned.
    """
    # Exactly one call to ``model.logpt()`` either way; only the sign differs.
    logpt_graph = model.logpt() if negative_logp else -model.logpt()
    jaxified = get_jaxified_graph(inputs=model.value_vars, outputs=[logpt_graph])

    def logp_fn_wrap(x):
        outputs = jaxified(*x)
        return outputs[0]

    return logp_fn_wrap
110
112
@@ -177,6 +179,202 @@ def _get_batched_jittered_initial_points(
177
179
return initial_points
178
180
179
181
182
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def _blackjax_inference_loop(
    seed,
    init_position,
    logprob_fn,
    draws,
    tune,
    target_accept,
    algorithm=None,
):
    """Run window adaptation and then ``draws`` sampling steps for one chain.

    Parameters
    ----------
    seed
        JAX PRNG key for this chain. It is split once so that warm-up and
        sampling each consume their own independent key.
    init_position
        Starting position handed to the adaptation run.
    logprob_fn : callable
        Log-probability function forwarded to ``blackjax.window_adaptation``.
    draws : int
        Number of post-adaptation steps; one PRNG sub-key is created per draw.
    tune : int
        Number of adaptation steps (``num_steps`` of the window adaptation).
    target_accept : float
        Target acceptance rate forwarded to the window adaptation.
    algorithm : optional
        blackjax algorithm to adapt; defaults to ``blackjax.nuts`` when ``None``.

    Returns
    -------
    tuple
        ``(states, infos)`` as stacked by ``jax.lax.scan`` over the sampling steps.
    """
    import blackjax

    if algorithm is None:
        algorithm = blackjax.nuts

    # A JAX PRNG key must not be consumed twice: feeding the same key to the
    # warm-up and to the sampling phase would correlate their random streams.
    adapt_key, sample_key = jax.random.split(seed)

    adapt = blackjax.window_adaptation(
        algorithm=algorithm,
        logprob_fn=logprob_fn,
        num_steps=tune,
        target_acceptance_rate=target_accept,
    )
    last_state, kernel, _ = adapt.run(adapt_key, init_position)

    def one_step(state, rng_key):
        # Carry the new state forward and also emit it (with its info) as output.
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(sample_key, draws)
    _, (states, infos) = jax.lax.scan(one_step, last_state, keys)

    return states, infos
216
+
217
+
218
def sample_blackjax_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    initvals=None,
    model=None,
    var_names=None,
    keep_untransformed=False,
    chain_method="parallel",
    idata_kwargs=None,
):
    """
    Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

    Parameters
    ----------
    draws : int, default 1000
        The number of samples to draw per chain. Tuning samples are not part of the result.
    tune : int, default 1000
        Number of iterations to tune. Samplers adjust the step sizes, scalings or
        similar during tuning. Tuning samples will be drawn in addition to the number specified in
        the ``draws`` argument.
    chains : int, default 4
        The number of chains to sample.
    target_accept : float in [0, 1].
        The step size is tuned such that we approximate this acceptance rate. Higher values like
        0.9 or 0.95 often work better for problematic posteriors.
    random_seed : int, default 10
        Random seed used for the jittered initial points and for the JAX PRNG key of the sampler.
    initvals : optional
        Initial values forwarded unchanged to ``_get_batched_jittered_initial_points``.
    model : Model, optional
        Model to sample from. The model needs to have free random variables. When inside a ``with`` model
        context, it defaults to that model, otherwise the model must be passed explicitly.
    var_names : iterable, optional
        Variables for which to compute the posterior samples. Defaults to
        ``model.unobserved_value_vars``.
    keep_untransformed : bool, default False
        Include untransformed variables in the posterior samples. Forwarded as
        ``include_transformed`` to ``get_default_varnames``.
    chain_method : str, default "parallel"
        Specify how chains are drawn: "parallel" maps chains with ``jax.pmap``,
        "vectorized" with ``jax.vmap``.
    idata_kwargs : dict, optional
        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
        for the ``log_likelihood`` key; pass ``False`` to skip computing the pointwise
        log likelihood.

    Returns
    -------
    InferenceData
        ArviZ ``InferenceData`` object that contains the posterior samples, the observed data and
        the pointwise log likelihood values (unless skipped with ``idata_kwargs``). No sample
        stats group is passed to ArviZ by this function.

    Raises
    ------
    ValueError
        If ``chain_method`` is neither "parallel" nor "vectorized".
    """
    import blackjax

    model = modelcontext(model)

    # Adapted from numpyro. Validate up front so that a bad ``chain_method``
    # fails before any initial-point or graph-compilation work is done.
    if chain_method == "parallel":
        map_fn = jax.pmap
    elif chain_method == "vectorized":
        map_fn = jax.vmap
    else:
        raise ValueError(
            'Only supporting the following methods to draw chains: "parallel" or "vectorized"'
        )

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

    # Tuples of coordinate values are converted to arrays; ``None`` coords are dropped.
    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in var_dims if dim is not None]
            for var_name, var_dims in model.RV_dims.items()
        }
    else:
        dims = {}

    tic1 = datetime.now()
    print("Compiling...", file=sys.stdout)

    init_params = _get_batched_jittered_initial_points(
        model=model,
        chains=chains,
        initvals=initvals,
        random_seed=random_seed,
    )

    if chains == 1:
        init_params = [np.stack(init_params)]
    init_params = [np.stack(init_state) for init_state in zip(*init_params)]

    logprob_fn = get_jaxified_logp(model)

    # One PRNG key per chain.
    seed = jax.random.PRNGKey(random_seed)
    keys = jax.random.split(seed, chains)

    get_posterior_samples = partial(
        _blackjax_inference_loop,
        logprob_fn=logprob_fn,
        tune=tune,
        draws=draws,
        target_accept=target_accept,
    )

    tic2 = datetime.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    states, _ = map_fn(get_posterior_samples)(keys, init_params)
    raw_mcmc_samples = states.position

    tic3 = datetime.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
        # Double vmap: the jaxified graph is mapped over two leading axes of the raw samples.
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    # Copy so that popping ``log_likelihood`` never mutates the caller's dict.
    if idata_kwargs is None:
        idata_kwargs = {}
    else:
        idata_kwargs = idata_kwargs.copy()

    if idata_kwargs.pop("log_likelihood", True):
        log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
    else:
        log_likelihood = None

    attrs = {
        "sampling_time": (tic3 - tic2).total_seconds(),
    }

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=log_likelihood,
        observed_data=find_observations(model),
        coords=coords,
        dims=dims,
        attrs=make_attrs(attrs, library=blackjax),
        **idata_kwargs,
    )

    return az_trace
376
+
377
+
180
378
def sample_numpyro_nuts (
181
379
draws : int = 1000 ,
182
380
tune : int = 1000 ,
@@ -192,6 +390,51 @@ def sample_numpyro_nuts(
192
390
idata_kwargs : Optional [Dict ] = None ,
193
391
nuts_kwargs : Optional [Dict ] = None ,
194
392
):
393
+ """
394
+ Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
395
+
396
+ Parameters
397
+ ----------
398
+ draws : int, default 1000
399
+ The number of samples to draw. The number of tuned samples are discarded by default.
400
+ tune : int, default 1000
401
+ Number of iterations to tune. Samplers adjust the step sizes, scalings or
402
+ similar during tuning. Tuning samples will be drawn in addition to the number specified in
403
+ the ``draws`` argument.
404
+ chains : int, default 4
405
+ The number of chains to sample.
406
+ target_accept : float in [0, 1].
407
+ The step size is tuned such that we approximate this acceptance rate. Higher values like
408
+ 0.9 or 0.95 often work better for problematic posteriors.
409
+ random_seed : int, default 10
410
+ Random seed used by the sampling steps.
411
+ model : Model, optional
412
+ Model to sample from. The model needs to have free random variables. When inside a ``with`` model
413
+ context, it defaults to that model, otherwise the model must be passed explicitly.
414
+ var_names : iterable of str, optional
415
+ Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior
416
+ progress_bar : bool, default True
417
+ Whether or not to display a progress bar in the command line. The bar shows the percentage
418
+ of completion, the sampling speed in samples per second (SPS), and the estimated remaining
419
+ time until completion ("expected time of arrival"; ETA).
420
+ keep_untransformed : bool, default False
421
+ Include untransformed variables in the posterior samples. Defaults to False.
422
+ chain_method : str, default "parallel"
423
+ Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
424
+ idata_kwargs : dict, optional
425
+ Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
426
+ for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
427
+ not be included in the returned object.
428
+
429
+ Returns
430
+ -------
431
+ InferenceData
432
+ ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
433
+ pointwise log likelihood values (unless skipped with ``idata_kwargs``).
434
+ """
435
+
436
+ import numpyro
437
+
195
438
from numpyro .infer import MCMC , NUTS
196
439
197
440
model = modelcontext (model )
@@ -228,7 +471,7 @@ def sample_numpyro_nuts(
228
471
random_seed = random_seed ,
229
472
)
230
473
231
- logp_fn = get_jaxified_logp (model )
474
+ logp_fn = get_jaxified_logp (model , negative_logp = False )
232
475
233
476
if nuts_kwargs is None :
234
477
nuts_kwargs = {}
@@ -298,6 +541,10 @@ def sample_numpyro_nuts(
298
541
else :
299
542
log_likelihood = None
300
543
544
+ attrs = {
545
+ "sampling_time" : (tic3 - tic2 ).total_seconds (),
546
+ }
547
+
301
548
posterior = mcmc_samples
302
549
az_trace = az .from_dict (
303
550
posterior = posterior ,
@@ -306,7 +553,7 @@ def sample_numpyro_nuts(
306
553
sample_stats = _sample_stats_to_xarray (pmap_numpyro ),
307
554
coords = coords ,
308
555
dims = dims ,
309
- attrs = { "sampling_time" : ( tic3 - tic2 ). total_seconds ()} ,
556
+ attrs = make_attrs ( attrs , library = numpyro ) ,
310
557
** idata_kwargs ,
311
558
)
312
559
0 commit comments