3535from pytensor .graph .fg import FunctionGraph
3636from pytensor .graph .replace import clone_replace
3737from pytensor .link .jax .dispatch import jax_funcify
38- from pytensor .raise_op import Assert
3938from pytensor .tensor import TensorVariable
4039from pytensor .tensor .random .type import RandomType
4140
4746)
4847from pymc .distributions .multivariate import PosDefMatrix
4948from pymc .initial_point import StartDict
50- from pymc .logprob .utils import CheckParameterValue
5149from pymc .sampling .mcmc import _init_jitter
5250from pymc .stats .convergence import log_warnings , run_convergence_checks
5351from pymc .util import (
7169)
7270
7371
74- @jax_funcify .register (Assert )
75- @jax_funcify .register (CheckParameterValue )
76- def jax_funcify_Assert (op , ** kwargs ):
77- # Jax does not allow assert whose values aren't known during JIT compilation
78- # within it's JIT-ed code. Hence we need to make a simple pass through
79- # version of the Assert Op.
80- # https://github.com/google/jax/issues/2273#issuecomment-589098722
81- def assert_fn (value , * inps ):
82- return value
83-
84- return assert_fn
85-
86-
8772@jax_funcify .register (PosDefMatrix )
8873def jax_funcify_PosDefMatrix (op , ** kwargs ):
8974 def posdefmatrix_fn (value , * inps ):
@@ -520,8 +505,6 @@ def sample_jax_nuts(
520505 keep_untransformed : bool = False ,
521506 chain_method : Literal ["parallel" , "vectorized" ] = "parallel" ,
522507 postprocessing_backend : Literal ["cpu" , "gpu" ] | None = None ,
523- postprocessing_vectorize : Literal ["vmap" , "scan" ] | None = None ,
524- postprocessing_chunks = None ,
525508 idata_kwargs : dict | None = None ,
526509 compute_convergence_checks : bool = True ,
527510 nuts_sampler : Literal ["numpyro" , "blackjax" ],
@@ -593,25 +576,6 @@ def sample_jax_nuts(
593576 with their respective sample stats and pointwise log likeihood values (unless
594577 skipped with ``idata_kwargs``).
595578 """
596- if postprocessing_chunks is not None :
597- import warnings
598-
599- warnings .warn (
600- "postprocessing_chunks is deprecated due to being unstable, "
601- "using postprocessing_vectorize='scan' instead" ,
602- DeprecationWarning ,
603- )
604-
605- if postprocessing_vectorize is not None :
606- import warnings
607-
608- warnings .warn (
609- 'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.' ,
610- FutureWarning ,
611- )
612- else :
613- postprocessing_vectorize = "vmap"
614-
615579 model = modelcontext (model )
616580
617581 if var_names is not None :
@@ -674,7 +638,6 @@ def sample_jax_nuts(
674638 model ,
675639 raw_mcmc_samples ,
676640 backend = postprocessing_backend ,
677- postprocessing_vectorize = postprocessing_vectorize ,
678641 )
679642 else :
680643 log_likelihood = None
@@ -684,7 +647,6 @@ def sample_jax_nuts(
684647 jax_fn ,
685648 raw_mcmc_samples ,
686649 postprocessing_backend = postprocessing_backend ,
687- postprocessing_vectorize = postprocessing_vectorize ,
688650 donate_samples = True ,
689651 )
690652 del raw_mcmc_samples
@@ -704,8 +666,8 @@ def sample_jax_nuts(
704666 dims .update (idata_kwargs .pop ("dims" ))
705667
706668 # Use 'partial' to set default arguments before passing 'idata_kwargs'
707- to_trace = partial (
708- az . from_dict ,
669+ idata = az . from_dict (
670+ posterior = mcmc_samples ,
709671 log_likelihood = log_likelihood ,
710672 observed_data = find_observations (model ),
711673 constant_data = find_constants (model ),
@@ -714,14 +676,13 @@ def sample_jax_nuts(
714676 dims = dims ,
715677 attrs = make_attrs (attrs , library = library ),
716678 posterior_attrs = make_attrs (attrs , library = library ),
679+ ** idata_kwargs ,
717680 )
718- az_trace = to_trace (posterior = mcmc_samples , ** idata_kwargs )
719681
720682 if compute_convergence_checks :
721- warns = run_convergence_checks (az_trace , model )
722- log_warnings (warns )
683+ log_warnings (run_convergence_checks (idata , model ))
723684
724- return az_trace
685+ return idata
725686
726687
727688sample_numpyro_nuts = partial (sample_jax_nuts , nuts_sampler = "numpyro" )
0 commit comments