Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
### Breaking Changes

### New Features
+ Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see [#4439](https://github.com/pymc-devs/pymc3/pull/4439)).
- Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see [#4439](https://github.com/pymc-devs/pymc3/pull/4439)).
- `pymc3.sampling_jax.sample_numpyro_nuts` now returns samples from transformed random variables, rather than from the unconstrained representation (see [#4427](https://github.com/pymc-devs/pymc3/pull/4427)).

### Maintenance
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
Expand Down
45 changes: 44 additions & 1 deletion pymc3/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import re
import warnings

from collections import defaultdict

# Ask XLA for 100 host-platform devices while preserving every other flag the
# user already placed in XLA_FLAGS. Flags are handled as whitespace-separated
# tokens so that each one keeps its leading "--" and only a pre-existing
# device-count setting is dropped (ours replaces it).
_device_count_flag = "--xla_force_host_platform_device_count"
xla_flags = [
    flag
    for flag in os.getenv("XLA_FLAGS", "").split()
    if not flag.startswith(_device_count_flag + "=")
]
os.environ["XLA_FLAGS"] = " ".join(["{}={}".format(_device_count_flag, 100)] + xla_flags)
Expand Down Expand Up @@ -121,6 +123,7 @@ def sample_numpyro_nuts(
random_seed=10,
model=None,
progress_bar=True,
keep_untransformed=False,
):
from numpyro.infer import MCMC, NUTS

Expand Down Expand Up @@ -175,8 +178,48 @@ def _sample(current_state, seed):
# print("Sampling time = ", tic4 - tic3)

posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
tic3 = pd.Timestamp.now()
posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
tic4 = pd.Timestamp.now()

az_trace = az.from_dict(posterior=posterior)
tic3 = pd.Timestamp.now()
print("Compilation + sampling time = ", tic3 - tic2)
print("Transformation time = ", tic4 - tic3)

return az_trace # , leapfrogs_taken, tic3 - tic2


def _transform_samples(samples, model, keep_untransformed=False):
    """Add the model's derived variables to a dict of posterior samples.

    Every unobserved variable of ``model`` that is not a free RV is computed
    from the free-RV samples and added to the result under its own name.

    Parameters
    ----------
    samples : dict
        Maps variable names to sampled values. Must contain an entry for
        every variable in ``model.free_RVs``. Each value is assumed to carry
        two leading batch axes (presumably chain and draw -- confirm against
        the caller), since the computation is vmapped twice.
    model : pymc3 model
        Provides ``free_RVs`` and ``unobserved_RVs``.
    keep_untransformed : bool
        Forwarded as ``include_transformed`` to
        ``pm.util.get_default_varnames`` when selecting which names to return
        (e.g. whether a name such as ``sigma_log__`` is kept).

    Returns
    -------
    dict
        A new dict holding the selected variables. The ``samples`` argument
        itself is left unmodified.
    """
    # Work on a shallow copy so the caller's dict is never mutated.
    samples = dict(samples)

    # Unobserved variables that were not sampled directly have to be derived
    # from the free RVs.
    free_rv_names = {x.name for x in model.free_RVs}
    ops_to_compute = [x for x in model.unobserved_RVs if x.name not in free_rv_names]

    # Skip graph construction entirely when there is nothing to derive.
    if ops_to_compute:
        # Function graph from the free RVs to the variables we need.
        fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)

        # Jaxify, which returns a list of functions, one for each op.
        jax_fns = jax_funcify(fgraph)

        # Inputs in the same order as the graph's inputs.
        inputs = [samples[x.name] for x in model.free_RVs]

        for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
            # Each jax function takes one positional argument per free RV;
            # the nested vmap maps it over the two leading axes of every input.
            samples[cur_op.name] = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)

    # Discard unwanted transformed variables, if desired.
    vars_to_keep = set(
        pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
    )
    return {name: values for name, values in samples.items() if name in vars_to_keep}
19 changes: 19 additions & 0 deletions pymc3/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np

import pymc3 as pm

from pymc3.sampling_jax import sample_numpyro_nuts


def test_transform_samples():
    """Samples of ``sigma`` must equal the exponential of ``sigma_log__``."""
    with pm.Model():
        scale = pm.HalfNormal("sigma")
        pm.Normal("b", sigma=scale)
        # Keep the log-space variable so both representations can be compared.
        idata = sample_numpyro_nuts(keep_untransformed=True)

    posterior = idata.posterior
    unconstrained = posterior["sigma_log__"].values
    constrained = posterior["sigma"].values

    assert np.allclose(np.exp(unconstrained), constrained)