Transform jax samples #4427

Merged
merged 5 commits into from
Feb 12, 2021
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
@@ -5,7 +5,8 @@
### Breaking Changes

### New Features
+ Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see[#4439](https://github.com/pymc-devs/pymc3/pull/4439)).
- Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see[#4439](https://github.com/pymc-devs/pymc3/pull/4439)).
- `pymc3.sampling_jax.sample_numpyro_nuts` now returns samples from transformed random variables, rather than from the unconstrained representation (see [#4427](https://github.com/pymc-devs/pymc3/pull/4427)).

### Maintenance
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
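
The `sample_numpyro_nuts` entry above distinguishes the unconstrained representation from the transformed variables. A minimal NumPy sketch of that distinction, assuming a positive parameter that is sampled on the log scale (as the `sigma_log__` / `sigma` pair in this PR's test suggests); the array names and shapes here are illustrative, not taken from the PR:

```python
import numpy as np

# Hypothetical unconstrained draws for a positive parameter,
# shaped (chains, draws). These stand in for "sigma_log__".
rng = np.random.default_rng(0)
sigma_log = rng.normal(size=(4, 1000))

# For a log-transformed variable, mapping back to the constrained
# space is an elementwise exponential. This stands in for "sigma".
sigma = np.exp(sigma_log)

# The constrained draws are strictly positive and round-trip
# back to the unconstrained draws.
print((sigma > 0).all())
print(np.allclose(np.log(sigma), sigma_log))
```

Before this change, only arrays like `sigma_log` would have been returned; after it, the constrained values are what the user sees by default.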
45 changes: 44 additions & 1 deletion pymc3/sampling_jax.py
@@ -3,6 +3,8 @@
import re
import warnings

from collections import defaultdict

xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
@@ -121,6 +123,7 @@ def sample_numpyro_nuts(
    random_seed=10,
    model=None,
    progress_bar=True,
    keep_untransformed=False,
):
    from numpyro.infer import MCMC, NUTS

@@ -175,8 +178,48 @@ def _sample(current_state, seed):
    # print("Sampling time = ", tic4 - tic3)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
    tic3 = pd.Timestamp.now()
    posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
    tic4 = pd.Timestamp.now()

    az_trace = az.from_dict(posterior=posterior)
    print("Compilation + sampling time = ", tic3 - tic2)
    print("Transformation time = ", tic4 - tic3)

    return az_trace  # , leapfrogs_taken, tic3 - tic2


def _transform_samples(samples, model, keep_untransformed=False):

    # Find out which RVs we need to compute:
    free_rv_names = {x.name for x in model.free_RVs}
    unobserved_names = {x.name for x in model.unobserved_RVs}

    names_to_compute = unobserved_names - free_rv_names
    ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]

    # Create function graph for these:
    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)

    # Jaxify, which returns a list of functions, one for each op
    jax_fns = jax_funcify(fgraph)

    # Put together the inputs
    inputs = [samples[x.name] for x in model.free_RVs]

    for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):

        # Each jax_fn takes the free RV values as positional arguments.
        # The nested vmap maps it over the two leading sample axes:
        result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)

        # Add to sample dict
        samples[cur_op.name] = result

    # Discard unwanted transformed variables, if desired:
    vars_to_keep = set(
        pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
    )
    samples = {x: y for x, y in samples.items() if x in vars_to_keep}

    return samples
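
The nested `jax.vmap` call in `_transform_samples` can be illustrated in isolation. The sketch below assumes samples are laid out as `(chains, draws, ...)` and uses a hypothetical per-draw function, `scaled_b`, in place of a function returned by `jax_funcify`; none of these names come from the PR:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-draw function: it sees one scalar and one
# length-3 vector, i.e. the values of two free RVs for a single draw.
def scaled_b(sigma_log, b):
    return b * jnp.exp(sigma_log)

# Fake samples: 2 chains, 4 draws each.
sigma_log = jnp.log(jnp.arange(1.0, 9.0).reshape(2, 4))  # shape (2, 4)
b = jnp.ones((2, 4, 3))                                   # shape (2, 4, 3)

# The outer vmap maps over the chain axis and the inner vmap over the
# draw axis, so scaled_b is evaluated once per (chain, draw) pair.
out = jax.vmap(jax.vmap(scaled_b))(sigma_log, b)

print(out.shape)
```

Calling `scaled_b(sigma_log, b)` directly on the full arrays would fail to broadcast `(2, 4, 3)` against `(2, 4)`, which is one reason to write the function for a single draw and let `vmap` handle the leading axes.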
19 changes: 19 additions & 0 deletions pymc3/tests/test_sampling_jax.py
@@ -0,0 +1,19 @@
import numpy as np

import pymc3 as pm

from pymc3.sampling_jax import sample_numpyro_nuts


def test_transform_samples():

    with pm.Model() as model:

        sigma = pm.HalfNormal("sigma")
        b = pm.Normal("b", sigma=sigma)
        trace = sample_numpyro_nuts(keep_untransformed=True)

    log_vals = trace.posterior["sigma_log__"].values
    trans_vals = trace.posterior["sigma"].values

    assert np.allclose(np.exp(log_vals), trans_vals)
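
The test passes `keep_untransformed=True` so that `sigma_log__` is still present in the posterior next to `sigma`. A rough sketch of that filtering step follows. It uses a hypothetical stand-in, `looks_transformed`, rather than `pm.util.get_default_varnames`, and assumes that automatically transformed variables are the ones whose names end in a double underscore:

```python
# Hypothetical stand-in for PyMC3's name check. Assumption: names of
# automatically transformed variables end in "__" (e.g. "sigma_log__").
def looks_transformed(name):
    return name.endswith("__")

def filter_samples(samples, keep_untransformed=False):
    # Keep every variable when keep_untransformed is True; otherwise
    # drop the ones that live on the unconstrained scale.
    return {
        name: values
        for name, values in samples.items()
        if keep_untransformed or not looks_transformed(name)
    }

samples = {"sigma_log__": [0.0, 0.1], "sigma": [1.0, 1.105], "b": [0.3, -0.2]}

default = filter_samples(samples)
everything = filter_samples(samples, keep_untransformed=True)

print(sorted(default))
print(sorted(everything))
```

With the default setting, a lookup such as `trace.posterior["sigma_log__"]` in the test above would find nothing to compare against, which is why the flag is set there.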