Skip to content

Commit e46f490

Browse files
martiningrammichaelosthegejunpenglao
authored
Transform jax samples (#4427)
* Transform samples from sample_numpyro_nuts * Add `pymc3.sampling_jax._transform_samples` function which transforms draws * Modify `pymc3.sampling_jax.sample_numpyro_nuts` function to use this function to return transformed samples * Add release note * Update pymc3/sampling_jax.py Co-authored-by: Junpeng Lao <[email protected]> * Added a small test * Split jax tests into their own workflow Co-authored-by: Michael Osthege <[email protected]> Co-authored-by: Junpeng Lao <[email protected]>
1 parent 41a25d5 commit e46f490

File tree

6 files changed

+142
-10
lines changed

6 files changed

+142
-10
lines changed

Diff for: .github/workflows/jaxtests.yml

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
name: jax-sampling
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches: [master]
7+
8+
jobs:
9+
pytest:
10+
strategy:
11+
matrix:
12+
os: [ubuntu-latest]
13+
floatx: [float64]
14+
test-subset:
15+
- pymc3/tests/test_sampling_jax.py
16+
fail-fast: false
17+
runs-on: ${{ matrix.os }}
18+
env:
19+
TEST_SUBSET: ${{ matrix.test-subset }}
20+
THEANO_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=native'
21+
defaults:
22+
run:
23+
shell: bash -l {0}
24+
steps:
25+
- uses: actions/checkout@v2
26+
- name: Cache conda
27+
uses: actions/cache@v1
28+
env:
29+
# Increase this value to reset cache if environment-dev-py39.yml has not changed
30+
CACHE_NUMBER: 0
31+
with:
32+
path: ~/conda_pkgs_dir
33+
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
34+
hashFiles('conda-envs/environment-dev-py39.yml') }}
35+
- name: Cache multiple paths
36+
uses: actions/cache@v2
37+
env:
38+
# Increase this value to reset cache if requirements.txt has not changed
39+
CACHE_NUMBER: 0
40+
with:
41+
path: |
42+
~/.cache/pip
43+
$RUNNER_TOOL_CACHE/Python/*
44+
~\AppData\Local\pip\Cache
45+
key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{
46+
hashFiles('requirements.txt') }}
47+
- uses: conda-incubator/setup-miniconda@v2
48+
with:
49+
activate-environment: pymc3-dev-py39
50+
channel-priority: strict
51+
environment-file: conda-envs/environment-dev-py39.yml
52+
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
53+
- name: Install pymc3
54+
run: |
55+
conda activate pymc3-dev-py39
56+
pip install -e .
57+
python --version
58+
- name: Install jax specific dependencies
59+
run: |
60+
conda activate pymc3-dev-py39
61+
pip install numpyro tensorflow_probability
62+
- name: Run tests
63+
run: |
64+
python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET

Diff for: .github/workflows/pytest.yml

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
--ignore=pymc3/tests/test_quadpotential.py
2828
--ignore=pymc3/tests/test_random.py
2929
--ignore=pymc3/tests/test_sampling.py
30+
--ignore=pymc3/tests/test_sampling_jax.py
3031
--ignore=pymc3/tests/test_shape_handling.py
3132
--ignore=pymc3/tests/test_shared.py
3233
--ignore=pymc3/tests/test_smc.py

Diff for: RELEASE-NOTES.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
### Breaking Changes
66

77
### New Features
8-
+ Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see[#4439](https://github.com/pymc-devs/pymc3/pull/4439)).
8+
- Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see[#4439](https://github.com/pymc-devs/pymc3/pull/4439)).
9+
- `pymc3.sampling_jax.sample_numpyro_nuts` now returns samples from transformed random variables, rather than from the unconstrained representation (see [#4427](https://github.com/pymc-devs/pymc3/pull/4427)).
910

1011
### Maintenance
1112
- We upgraded to `Theano-PyMC v1.1.2` which [includes bugfixes](https://github.com/pymc-devs/aesara/compare/rel-1.1.0...rel-1.1.2) for warning floods and compiledir locking (see [#4444](https://github.com/pymc-devs/pymc3/pull/4444))

Diff for: pymc3/sampling_jax.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import re
44
import warnings
55

6+
from collections import defaultdict
7+
68
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
79
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
810
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
@@ -121,6 +123,7 @@ def sample_numpyro_nuts(
121123
random_seed=10,
122124
model=None,
123125
progress_bar=True,
126+
keep_untransformed=False,
124127
):
125128
from numpyro.infer import MCMC, NUTS
126129

@@ -175,8 +178,48 @@ def _sample(current_state, seed):
175178
# print("Sampling time = ", tic4 - tic3)
176179

177180
posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
181+
tic3 = pd.Timestamp.now()
182+
posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
183+
tic4 = pd.Timestamp.now()
178184

179185
az_trace = az.from_dict(posterior=posterior)
180-
tic3 = pd.Timestamp.now()
181186
print("Compilation + sampling time = ", tic3 - tic2)
187+
print("Transformation time = ", tic4 - tic3)
188+
182189
return az_trace # , leapfrogs_taken, tic3 - tic2
190+
191+
192+
def _transform_samples(samples, model, keep_untransformed=False):
193+
194+
# Find out which RVs we need to compute:
195+
free_rv_names = {x.name for x in model.free_RVs}
196+
unobserved_names = {x.name for x in model.unobserved_RVs}
197+
198+
names_to_compute = unobserved_names - free_rv_names
199+
ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]
200+
201+
# Create function graph for these:
202+
fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
203+
204+
# Jaxify, which returns a list of functions, one for each op
205+
jax_fns = jax_funcify(fgraph)
206+
207+
# Put together the inputs
208+
inputs = [samples[x.name] for x in model.free_RVs]
209+
210+
for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
211+
212+
# We need a function taking a single argument to run vmap, while the
213+
# jax_fn takes a list, so:
214+
result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)
215+
216+
# Add to sample dict
217+
samples[cur_op.name] = result
218+
219+
# Discard unwanted transformed variables, if desired:
220+
vars_to_keep = set(
221+
pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
222+
)
223+
samples = {x: y for x, y in samples.items() if x in vars_to_keep}
224+
225+
return samples

Diff for: pymc3/tests/test_sampling_jax.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import numpy as np
2+
3+
import pymc3 as pm
4+
5+
from pymc3.sampling_jax import sample_numpyro_nuts
6+
7+
8+
def test_transform_samples():
9+
10+
with pm.Model() as model:
11+
12+
sigma = pm.HalfNormal("sigma")
13+
b = pm.Normal("b", sigma=sigma)
14+
trace = sample_numpyro_nuts(keep_untransformed=True)
15+
16+
log_vals = trace.posterior["sigma_log__"].values
17+
trans_vals = trace.posterior["sigma"].values
18+
19+
assert np.allclose(np.exp(log_vals), trans_vals)

Diff for: scripts/check_all_tests_are_covered.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@
1212
from pathlib import Path
1313

1414
if __name__ == "__main__":
15-
pytest_ci_job = Path(".github") / "workflows/pytest.yml"
16-
txt = pytest_ci_job.read_text()
17-
ignored_tests = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt))
18-
non_ignored_tests = set(re.findall(r"(?<!--ignore=)(pymc3/tests.*\.py)", txt))
15+
testing_workflows = ["jaxtests.yml", "pytest.yml"]
16+
ignored = set()
17+
non_ignored = set()
18+
for wfyml in testing_workflows:
19+
pytest_ci_job = Path(".github") / "workflows" / wfyml
20+
txt = pytest_ci_job.read_text()
21+
ignored = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt))
22+
non_ignored = non_ignored.union(set(re.findall(r"(?<!--ignore=)(pymc3/tests.*\.py)", txt)))
1923
assert (
20-
ignored_tests <= non_ignored_tests
21-
), f"The following tests are ignored by the first job but not run by the others: {ignored_tests.difference(non_ignored_tests)}"
24+
ignored <= non_ignored
25+
), f"The following tests are ignored by the first job but not run by the others: {ignored.difference(non_ignored)}"
2226
assert (
23-
ignored_tests >= non_ignored_tests
24-
), f"The following tests are run by multiple jobs: {non_ignored_tests.difference(ignored_tests)}"
27+
ignored >= non_ignored
28+
), f"The following tests are run by multiple jobs: {non_ignored.difference(ignored)}"

0 commit comments

Comments
 (0)