From 7d10905a8e44922ce8e7fb7bf875605b6f3065e8 Mon Sep 17 00:00:00 2001 From: "j.niesser" Date: Fri, 6 Oct 2023 17:00:03 +0200 Subject: [PATCH 1/2] Add constant and observed data to nutpie idata Closes https://github.com/pymc-devs/nutpie/issues/74 --- pymc/sampling/mcmc.py | 40 ++++++++++++++++++++++++++-- tests/sampling/test_mcmc_external.py | 12 ++++++--- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4afb93a7fe..bccc8c8459 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -39,7 +39,8 @@ import numpy as np import pytensor.gradient as tg -from arviz import InferenceData +from arviz import InferenceData, dict_to_dataset +from arviz.data.base import make_attrs from fastprogress.fastprogress import progress_bar from pytensor.graph.basic import Variable from typing_extensions import Protocol, TypeAlias @@ -47,6 +48,11 @@ import pymc as pm from pymc.backends import RunType, TraceOrBackend, init_traces +from pymc.backends.arviz import ( + coords_and_dims_for_inferencedata, + find_constants, + find_observations, +) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError @@ -293,8 +299,24 @@ def _sample_external_nuts( "`idata_kwargs` are currently ignored by the nutpie sampler", UserWarning, ) - + # gather observed and constant data as nutpie.sample() has no access to the PyMC model + coords, dims = coords_and_dims_for_inferencedata(model) + constant_data = dict_to_dataset( + find_constants(model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + observed_data = dict_to_dataset( + find_observations(model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) compiled_model = nutpie.compile_pymc_model(model) + t_start = time.time() idata = nutpie.sample( compiled_model, draws=draws, @@ -305,6 +327,20 @@ def _sample_external_nuts( progress_bar=progressbar, **nuts_sampler_kwargs, ) + t_sample = time.time() - t_start + attrs = make_attrs( + { + "sampling_time": t_sample, + }, + library=nutpie, + ) + for k, v in attrs.items(): + idata.posterior.attrs[k] = v + idata.add_groups( + {"constant_data": constant_data, "observed_data": observed_data}, + coords=coords, + dims=dims, + ) return idata elif sampler == "numpyro": diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index fc3e6e1551..d7370a560d 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -16,7 +16,7 @@ import numpy.testing as npt import pytest -from pymc import Model, Normal, sample +from pymc import ConstantData, Model, Normal, sample @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) @@ -25,7 +25,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): pytest.importorskip(nuts_sampler) with Model(): - Normal("x") + x = Normal("x", 100, 5) + y = ConstantData("y", [1, 2, 3, 4]) + ConstantData("z", [100, 190, 310, 405]) + + Normal("L", mu=x, sigma=0.1, observed=y) kwargs = dict( nuts_sampler=nuts_sampler, @@ -55,7 +59,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): ) ) assert warns == expected - + assert "y" in idata1.constant_data + assert "z" in idata1.constant_data + assert "L" in idata1.observed_data assert idata1.posterior.chain.size == 2 assert idata1.posterior.draw.size == 500 np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) From 776ff03ebe451f0ad7755096ae89ff93f8b52ff7 Mon Sep 17 00:00:00 2001 From: "j.niesser" Date: Thu, 12 Oct 2023 13:40:56 +0200 Subject: [PATCH 2/2] change order and add comment to revert --- pymc/sampling/mcmc.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index bccc8c8459..1a1fc45e6a 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -299,6 +299,20 @@ def _sample_external_nuts( "`idata_kwargs` are currently ignored by the nutpie sampler", UserWarning, ) + compiled_model = nutpie.compile_pymc_model(model) + t_start = time.time() + idata = nutpie.sample( + compiled_model, + draws=draws, + tune=tune, + chains=chains, + target_accept=target_accept, + seed=_get_seeds_per_chain(random_seed, 1)[0], + progress_bar=progressbar, + **nuts_sampler_kwargs, + ) + t_sample = time.time() - t_start + # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed # gather observed and constant data as nutpie.sample() has no access to the PyMC model coords, dims = coords_and_dims_for_inferencedata(model) constant_data = dict_to_dataset( @@ -315,19 +329,6 @@ def _sample_external_nuts( dims=dims, default_dims=[], ) - compiled_model = nutpie.compile_pymc_model(model) - t_start = time.time() - idata = nutpie.sample( - compiled_model, - draws=draws, - tune=tune, - chains=chains, - target_accept=target_accept, - seed=_get_seeds_per_chain(random_seed, 1)[0], - progress_bar=progressbar, - **nuts_sampler_kwargs, - ) - t_sample = time.time() - t_start attrs = make_attrs( { "sampling_time": t_sample,