Skip to content

Add constant and observed data to nutpie idata #6943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,20 @@
import numpy as np
import pytensor.gradient as tg

from arviz import InferenceData
from arviz import InferenceData, dict_to_dataset
from arviz.data.base import make_attrs
from fastprogress.fastprogress import progress_bar
from pytensor.graph.basic import Variable
from typing_extensions import Protocol, TypeAlias

import pymc as pm

from pymc.backends import RunType, TraceOrBackend, init_traces
from pymc.backends.arviz import (
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
)
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
Expand Down Expand Up @@ -293,8 +299,24 @@ def _sample_external_nuts(
"`idata_kwargs` are currently ignored by the nutpie sampler",
UserWarning,
)

# gather observed and constant data as nutpie.sample() has no access to the PyMC model
coords, dims = coords_and_dims_for_inferencedata(model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be done after sampling (if sampling fails, there's no point in doing this work in advance).

Please add a comment along the lines of `# Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed`.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a very cheap function call:

pymc/pymc/backends/arviz.py

Lines 103 to 112 in 6f4a040

def coords_and_dims_for_inferencedata(model: Model) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Parse PyMC model coords and dims format to one accepted by InferenceData.

    Parameters
    ----------
    model
        Object exposing ``coords`` (mapping of coordinate name to coordinate
        values) and ``named_vars_to_dims`` (mapping of variable name to an
        iterable of dimension names).

    Returns
    -------
    coords : dict
        Every entry of ``model.coords`` whose values are not ``None``.
        Tuple values are converted to NumPy arrays; all other values are
        passed through unchanged.
    dims : dict
        Every entry of ``model.named_vars_to_dims`` with its dimension names
        materialised as a new list.
    """
    coords = {
        # Only tuples are converted; other containers are passed through as-is.
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        # Coordinates without values (None) are omitted entirely.
        if cvals is not None
    }
    dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()}
    return coords, dims

Sometimes this step also catches issues with coordinate value types, which is why I recommended putting it before sampling (just like gathering the constant and observed data).

constant_data = dict_to_dataset(
find_constants(model),
library=pm,
coords=coords,
dims=dims,
default_dims=[],
)
observed_data = dict_to_dataset(
find_observations(model),
library=pm,
coords=coords,
dims=dims,
default_dims=[],
)
compiled_model = nutpie.compile_pymc_model(model)
t_start = time.time()
idata = nutpie.sample(
compiled_model,
draws=draws,
Expand All @@ -305,6 +327,20 @@ def _sample_external_nuts(
progress_bar=progressbar,
**nuts_sampler_kwargs,
)
t_sample = time.time() - t_start
attrs = make_attrs(
{
"sampling_time": t_sample,
},
library=nutpie,
)
for k, v in attrs.items():
idata.posterior.attrs[k] = v
idata.add_groups(
{"constant_data": constant_data, "observed_data": observed_data},
coords=coords,
dims=dims,
)
return idata

elif sampler == "numpyro":
Expand Down
12 changes: 9 additions & 3 deletions tests/sampling/test_mcmc_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy.testing as npt
import pytest

from pymc import Model, Normal, sample
from pymc import ConstantData, Model, Normal, sample


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
Expand All @@ -25,7 +25,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
pytest.importorskip(nuts_sampler)

with Model():
Normal("x")
x = Normal("x", 100, 5)
y = ConstantData("y", [1, 2, 3, 4])
ConstantData("z", [100, 190, 310, 405])

Normal("L", mu=x, sigma=0.1, observed=y)

kwargs = dict(
nuts_sampler=nuts_sampler,
Expand Down Expand Up @@ -55,7 +59,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
)
)
assert warns == expected

assert "y" in idata1.constant_data
assert "z" in idata1.constant_data
assert "L" in idata1.observed_data
assert idata1.posterior.chain.size == 2
assert idata1.posterior.draw.size == 500
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)
Expand Down