Skip to content

Commit fc09a79

Browse files
authored
Michaelosthege/issue 3828 (#2)
* xarray test for fast posterior predictive sampling. * Move Dataset translation to util. The translation from xarray Dataset to a list of points was previously open-coded into sample_posterior_predictive. Pulled it out so it can be used in both spp and fast_sample_posterior_predictive. * fast_sample_posterior_predictive support for xarray traces.
1 parent 796c9bb commit fc09a79

File tree

6 files changed

+1292
-18
lines changed

6 files changed

+1292
-18
lines changed

docs/source/build.out

+1,059
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import pymc3"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": 2,
15+
"metadata": {},
16+
"outputs": [
17+
{
18+
"name": "stderr",
19+
"output_type": "stream",
20+
"text": [
21+
"Auto-assigning NUTS sampler...\n",
22+
"Initializing NUTS using jitter+adapt_diag...\n",
23+
"Multiprocess sampling (4 chains in 4 jobs)\n",
24+
"NUTS: [n]\n"
25+
]
26+
},
27+
{
28+
"data": {
29+
"text/html": [
30+
"\n",
31+
" <div>\n",
32+
" <style>\n",
33+
" /* Turns off some styling */\n",
34+
" progress {\n",
35+
" /* gets rid of default border in Firefox and Opera. */\n",
36+
" border: none;\n",
37+
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
38+
" background-size: auto;\n",
39+
" }\n",
40+
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
41+
" background: #F44336;\n",
42+
" }\n",
43+
" </style>\n",
44+
" <progress value='4000' class='' max='4000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
45+
" 100.00% [4000/4000 00:00<00:00 Sampling 4 chains, 0 divergences]\n",
46+
" </div>\n",
47+
" "
48+
],
49+
"text/plain": [
50+
"<IPython.core.display.HTML object>"
51+
]
52+
},
53+
"metadata": {},
54+
"output_type": "display_data"
55+
},
56+
{
57+
"name": "stderr",
58+
"output_type": "stream",
59+
"text": [
60+
"The acceptance probability does not match the target. It is 0.8816556891941705, but should be close to 0.8. Try to increase the number of tuning steps.\n"
61+
]
62+
}
63+
],
64+
"source": [
65+
"with pymc3.Model() as pmodel:\n",
66+
" n = pymc3.Normal('n')\n",
67+
" trace = pymc3.sample()\n",
68+
"\n",
69+
"with pmodel:\n",
70+
" d = pymc3.Deterministic('d', n * 4)"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": 3,
76+
"metadata": {},
77+
"outputs": [
78+
{
79+
"name": "stdout",
80+
"output_type": "stream",
81+
"text": [
82+
"CPU times: user 18.8 ms, sys: 3.42 ms, total: 22.2 ms\n",
83+
"Wall time: 23.8 ms\n"
84+
]
85+
}
86+
],
87+
"source": [
88+
"%%time\n",
89+
"with pmodel:\n",
90+
" pp = pymc3.fast_sample_posterior_predictive(\n",
91+
" [trace[15]],\n",
92+
" var_names=['d']\n",
93+
" )"
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": 4,
99+
"metadata": {},
100+
"outputs": [
101+
{
102+
"ename": "AttributeError",
103+
"evalue": "'list' object has no attribute '_straces'",
104+
"output_type": "error",
105+
"traceback": [
106+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
107+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
108+
"\u001b[0;32m<timed exec>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n",
109+
"\u001b[0;32m~/src/pymc3/pymc3/sampling.py\u001b[0m in \u001b[0;36msample_posterior_predictive\u001b[0;34m(trace, samples, model, vars, var_names, size, keep_size, random_seed, progressbar)\u001b[0m\n\u001b[1;32m 1539\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1540\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msamples\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1541\u001b[0;31m \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_straces\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1543\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msamples\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mlen_trace\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnchain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
110+
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute '_straces'"
111+
]
112+
}
113+
],
114+
"source": [
115+
"%%time\n",
116+
"with pmodel:\n",
117+
" pp = pymc3.sample_posterior_predictive(\n",
118+
" [trace[15]],\n",
119+
" var_names=['d']\n",
120+
" )"
121+
]
122+
},
123+
{
124+
"cell_type": "code",
125+
"execution_count": 5,
126+
"metadata": {},
127+
"outputs": [
128+
{
129+
"data": {
130+
"text/plain": [
131+
"{'n': 0.691903087470128}"
132+
]
133+
},
134+
"execution_count": 5,
135+
"metadata": {},
136+
"output_type": "execute_result"
137+
}
138+
],
139+
"source": [
140+
"trace[15]"
141+
]
142+
},
143+
{
144+
"cell_type": "code",
145+
"execution_count": 8,
146+
"metadata": {},
147+
"outputs": [
148+
{
149+
"data": {
150+
"text/plain": [
151+
"True"
152+
]
153+
},
154+
"execution_count": 8,
155+
"metadata": {},
156+
"output_type": "execute_result"
157+
}
158+
],
159+
"source": [
160+
"'MultiTrace' in dir(pymc3.backends.base)"
161+
]
162+
},
163+
{
164+
"cell_type": "code",
165+
"execution_count": null,
166+
"metadata": {},
167+
"outputs": [],
168+
"source": []
169+
}
170+
],
171+
"metadata": {
172+
"kernelspec": {
173+
"display_name": "Python 3",
174+
"language": "python",
175+
"name": "python3"
176+
},
177+
"language_info": {
178+
"codemirror_mode": {
179+
"name": "ipython",
180+
"version": 3
181+
},
182+
"file_extension": ".py",
183+
"mimetype": "text/x-python",
184+
"name": "python",
185+
"nbconvert_exporter": "python",
186+
"pygments_lexer": "ipython3",
187+
"version": "3.7.6"
188+
}
189+
},
190+
"nbformat": 4,
191+
"nbformat_minor": 2
192+
}

pymc3/distributions/posterior_predictive.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
import numpy as np
1313
import theano
1414
import theano.tensor as tt
15+
from xarray import Dataset
1516

1617
from ..backends.base import MultiTrace #, TraceLike, TraceDict
1718
from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc
1819
from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext
1920
from ..exceptions import IncorrectArgumentsError
2021
from ..vartypes import theano_constant
22+
from ..util import dataset_to_point_dict
2123
# Failing tests:
2224
# test_mixture_random_shape::test_mixture_random_shape
2325
#
@@ -119,7 +121,7 @@ def __getitem__(self, item):
119121

120122

121123

122-
def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.ndarray]]],
124+
def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]],
123125
samples: Optional[int]=None,
124126
model: Optional[Model]=None,
125127
var_names: Optional[List[str]]=None,
@@ -135,7 +137,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
135137
136138
Parameters
137139
----------
138-
trace : MultiTrace or List of points
140+
trace : MultiTrace, xarray.Dataset, or List of points (dictionary)
139141
Trace generated from MCMC sampling.
140142
samples : int, optional
141143
Number of posterior predictive samples to generate. Defaults to one posterior predictive
@@ -168,6 +170,9 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
168170
### greater than the number of samples in the trace parameter, we sample repeatedly. This
169171
### makes the shape issues just a little easier to deal with.
170172

173+
if isinstance(trace, Dataset):
174+
trace = dataset_to_point_dict(trace)
175+
171176
model = modelcontext(model)
172177
assert model is not None
173178
with model:

pymc3/sampling.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
get_untransformed_name,
5555
is_transformed_name,
5656
get_default_varnames,
57+
dataset_to_point_dict,
5758
)
5859
from .vartypes import discrete_types
5960
from .exceptions import IncorrectArgumentsError
@@ -1558,21 +1559,7 @@ def sample_posterior_predictive(
15581559
posterior predictive samples.
15591560
"""
15601561
if isinstance(trace, xarray.Dataset):
1561-
# grab posterior samples for each variable
1562-
_samples = {
1563-
vn : trace[vn].values
1564-
for vn in trace.keys()
1565-
}
1566-
# make dicts
1567-
points = []
1568-
for c in trace.chain:
1569-
for d in trace.draw:
1570-
points.append({
1571-
vn : s[c, d]
1572-
for vn, s in _samples.items()
1573-
})
1574-
# use the list of points
1575-
trace = points
1562+
trace = dataset_to_point_dict(trace)
15761563

15771564
len_trace = len(trace)
15781565
try:

pymc3/tests/test_sampling.py

+9
Original file line numberDiff line numberDiff line change
@@ -901,3 +901,12 @@ def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
901901
idat.posterior,
902902
var_names=['d']
903903
)
904+
905+
def test_sample_from_xarray_posterior_fast(self, point_list_arg_bug_fixture):
906+
pmodel, trace = point_list_arg_bug_fixture
907+
idat = az.from_pymc3(trace)
908+
with pmodel:
909+
pp = pm.fast_sample_posterior_predictive(
910+
idat.posterior,
911+
var_names=['d']
912+
)

pymc3/util.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414

1515
import re
1616
import functools
17-
from numpy import asscalar
17+
from typing import List, Dict
18+
19+
import xarray
20+
from numpy import asscalar, ndarray
21+
1822

1923
LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE)
2024

@@ -179,3 +183,21 @@ def enhanced(*args, **kwargs):
179183
newwrapper = functools.partial(wrapper, *args, **kwargs)
180184
return newwrapper
181185
return enhanced
186+
187+
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
188+
# grab posterior samples for each variable
189+
_samples = {
190+
vn : ds[vn].values
191+
for vn in ds.keys()
192+
}
193+
# make dicts
194+
points = []
195+
for c in ds.chain:
196+
for d in ds.draw:
197+
points.append({
198+
vn : s[c, d]
199+
for vn, s in _samples.items()
200+
})
201+
# use the list of points
202+
ds = points
203+
return ds

0 commit comments

Comments
 (0)