Skip to content

Commit b71c83b

Browse files
add PyStan3 reloo example [docs] (#1583)
* add pystan3 reloo example * add nest_asyncio * fix class imports * fix super call * update workbook * update wrappers * move pystan refitting files * update sampling wrappers md. * rename pystan3 to pystan * update wrapper notebook * fix mypy * fix missing import * black * update doc references Co-authored-by: Oriol (ZBook) <[email protected]>
1 parent 1b5d24f commit b71c83b

File tree

8 files changed

+693
-91
lines changed

8 files changed

+693
-91
lines changed

Diff for: arviz/wrappers/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Sampling wrappers."""
22
from .base import SamplingWrapper
3-
from .wrap_pystan import PyStanSamplingWrapper
3+
from .wrap_stan import PyStan2SamplingWrapper, PyStanSamplingWrapper
44

5-
__all__ = ["SamplingWrapper", "PyStanSamplingWrapper"]
5+
__all__ = ["SamplingWrapper", "PyStan2SamplingWrapper", "PyStanSamplingWrapper"]

Diff for: arviz/wrappers/base.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pylint: disable=too-many-instance-attributes,too-many-arguments
12
"""Base class for sampling wrappers."""
23
from xarray import apply_ufunc
34

@@ -230,11 +231,9 @@ def check_implemented_methods(self, methods):
230231
if method in supported_methods_1arg:
231232
if self._check_method_is_implemented(method, 1):
232233
continue
233-
else:
234-
not_implemented.append(method)
234+
not_implemented.append(method)
235235
elif method in supported_methods_2args:
236236
if self._check_method_is_implemented(method, 1, 1):
237237
continue
238-
else:
239-
not_implemented.append(method)
238+
not_implemented.append(method)
240239
return not_implemented

Diff for: arviz/wrappers/wrap_pystan.py renamed to arviz/wrappers/wrap_stan.py

+61-9
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
# pylint: disable=arguments-differ
22
"""Base class for PyStan wrappers."""
3-
from ..data import from_pystan
3+
from typing import Union
4+
5+
from ..data import from_cmdstanpy, from_pystan
46
from .base import SamplingWrapper
57

68

7-
class PyStanSamplingWrapper(SamplingWrapper):
8-
"""PyStan sampling wrapper base class.
9+
# pylint: disable=abstract-method
10+
class StanSamplingWrapper(SamplingWrapper):
11+
"""Stan sampling wrapper base class.
912
1013
See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
1114
description. An example of ``PyStanSamplingWrapper`` usage can be found
@@ -47,16 +50,65 @@ def sel_observations(self, idx):
4750
"""
4851
raise NotImplementedError("sel_observations must be implemented on a model basis")
4952

50-
def sample(self, modified_observed_data):
51-
"""Resample the PyStan model stored in self.model on modified_observed_data."""
52-
fit = self.model.sampling(data=modified_observed_data, **self.sample_kwargs)
53-
return fit
54-
5553
def get_inference_data(self, fit):
5654
"""Convert the fit object returned by ``self.sample`` to InferenceData."""
57-
idata = from_pystan(posterior=fit, **self.idata_kwargs)
55+
if fit.__class__.__name__ == "CmdStanMCMC":
56+
idata = from_cmdstanpy(posterior=fit, **self.idata_kwargs)
57+
else:
58+
idata = from_pystan(posterior=fit, **self.idata_kwargs)
5859
return idata
5960

6061
def log_likelihood__i(self, excluded_obs_log_like, idata__i):
6162
"""Retrieve the log likelihood of the excluded observations from ``idata__i``."""
6263
return idata__i.log_likelihood[excluded_obs_log_like]
64+
65+
66+
class PyStan2SamplingWrapper(StanSamplingWrapper):
67+
"""PyStan (2.x) sampling wrapper base class.
68+
69+
See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
70+
description. An example of ``PyStan2SamplingWrapper`` usage can be found
71+
in the :ref:`pystan2_refitting` notebook. For usage examples of other wrappers
72+
see the user guide pages on :ref:`wrapper_guide`.
73+
74+
Warnings
75+
--------
76+
Sampling wrappers are an experimental feature in a very early stage. Please use them
77+
with caution.
78+
79+
See Also
80+
--------
81+
SamplingWrapper
82+
"""
83+
84+
def sample(self, modified_observed_data):
85+
"""Resample the PyStan model stored in self.model on modified_observed_data."""
86+
fit = self.model.sampling(data=modified_observed_data, **self.sample_kwargs)
87+
return fit
88+
89+
90+
class PyStanSamplingWrapper(StanSamplingWrapper):
91+
"""PyStan (3.0+) sampling wrapper base class.
92+
93+
See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
94+
description. An example of ``PyStanSamplingWrapper`` usage can be found
95+
in the :ref:`pystan_refitting` notebook.
96+
97+
Warnings
98+
--------
99+
Sampling wrappers are an experimental feature in a very early stage. Please use them
100+
with caution.
101+
"""
102+
103+
def sample(self, modified_observed_data):
104+
"""Rebuild and resample the PyStan model on modified_observed_data."""
105+
import stan # pylint: disable=import-error,import-outside-toplevel
106+
107+
self.model: Union[str, stan.Model]
108+
if isinstance(self.model, str):
109+
program_code = self.model
110+
else:
111+
program_code = self.model.program_code
112+
self.model = stan.build(program_code, data=modified_observed_data)
113+
fit = self.model.sample(**self.sample_kwargs)
114+
return fit

Diff for: doc/source/api/wrappers.rst

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ Experimental feature
1111

1212
SamplingWrapper
1313
PyStanSamplingWrapper
14+
PyStan2SamplingWrapper

Diff for: doc/source/user_guide/pystan2_refitting.ipynb

+460
Large diffs are not rendered by default.

Diff for: doc/source/user_guide/pystan_refitting_xr_lik.ipynb renamed to doc/source/user_guide/pystan2_refitting_xr_lik.ipynb

+10-10
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"(pystan_refitting_xr)=\n",
8-
"# Refitting PyStan models with ArviZ (and xarray)\n",
7+
"(pystan2_refitting_xr)=\n",
8+
"# Refitting PyStan (2.x) models with ArviZ (and xarray)\n",
99
"\n",
1010
"ArviZ is backend agnostic and therefore does not sample directly. In order to take advantage of algorithms that require refitting models several times, ArviZ uses {class}`~arviz.SamplingWrapper`s to convert the API of the sampling backend to a common set of functions. Hence, functions like Leave Future Out Cross Validation can be used in ArviZ independently of the sampling backend used."
1111
]
@@ -14,7 +14,7 @@
1414
"cell_type": "markdown",
1515
"metadata": {},
1616
"source": [
17-
"Below there is one example of `SamplingWrapper` usage for PyStan."
17+
"Below there is one example of `SamplingWrapper` usage for PyStan (2.x)."
1818
]
1919
},
2020
{
@@ -124,7 +124,7 @@
124124
" vector[N] y_hat;\n",
125125
" \n",
126126
" for (i in 1:N) {\n",
127-
" // pointwise log likelihood will be calculated outside stan, \n",
127+
" // pointwise log likelihood will be calculated outside Stan, \n",
128128
" // posterior predictive however will be generated here, there are \n",
129129
" // no restrictions on adding more generated quantities\n",
130130
" y_hat[i] = normal_rng(b0 + b1 * x[i], sigma_e);\n",
@@ -195,7 +195,7 @@
195195
"source": [
196196
"We are now missing the `log_likelihood` group because we have not used the `log_likelihood` argument in `idata_kwargs`. We are doing this to ease the job of the sampling wrapper. Instead of going out of our way to get Stan to calculate the pointwise log likelihood values for each refit and for the excluded observation at every refit, we will compromise and manually write a function to calculate the pointwise log likelihood.\n",
197197
"\n",
198-
"Even though it is not ideal to lose part of the straight out of the box capabilities of PyStan-ArviZ integration, this should generally not be a problem. We are basically moving the pointwise log likelihood calculation from the Stan code to the Python code, in both cases we need to manyally write the function to calculate the pointwise log likelihood.\n",
198+
"Even though it is not ideal to lose part of the straight out of the box capabilities of PyStan-ArviZ integration, this should generally not be a problem. We are basically moving the pointwise log likelihood calculation from the Stan code to the Python code, in both cases we need to manually write the function to calculate the pointwise log likelihood.\n",
199199
"\n",
200200
"Moreover, the Python computation could even be written to be compatible with Dask. Thus it will work even in cases where the large number of observations makes it impossible to store pointwise log likelihood values (with shape `n_samples * n_observations`) in memory."
201201
]
@@ -3181,14 +3181,14 @@
31813181
"cell_type": "markdown",
31823182
"metadata": {},
31833183
"source": [
3184-
"We will create a subclass of {class}`~arviz.SamplingWrapper`. Therefore, instead of having to implement all functions required by {func}`~arviz.reloo` we only have to implement `sel_observations` (we are cloning `sample` and `get_inference_data` from the `PyStanSamplingWrapper` in order to use `apply_ufunc` instead of assuming the log likelihood is calculated within Stan). \n",
3184+
"We will create a subclass of {class}`~arviz.SamplingWrapper`. Therefore, instead of having to implement all functions required by {func}`~arviz.reloo` we only have to implement `sel_observations` (we are cloning `sample` and `get_inference_data` from the `PyStan2SamplingWrapper` in order to use `apply_ufunc` instead of assuming the log likelihood is calculated within Stan). \n",
31853185
"\n",
31863186
"Note that of the 2 outputs of `sel_observations`, `data__i` is a dictionary because it is an argument of `sample` which will pass it as is to `model.sampling`, whereas `data_ex` is a list because it is an argument to `log_likelihood__i` which will pass it as `*data_ex` to `apply_ufunc`. More on `data_ex` and `apply_ufunc` integration below."
31873187
]
31883188
},
31893189
{
31903190
"cell_type": "code",
3191-
"execution_count": 12,
3191+
"execution_count": null,
31923192
"metadata": {},
31933193
"outputs": [],
31943194
"source": [
@@ -3202,7 +3202,7 @@
32023202
" return data__i, data_ex\n",
32033203
" \n",
32043204
" def sample(self, modified_observed_data):\n",
3205-
" #Cloned from PyStanSamplingWrapper.\n",
3205+
" #Cloned from PyStan2SamplingWrapper.\n",
32063206
" fit = self.model.sampling(data=modified_observed_data, **self.sample_kwargs)\n",
32073207
" return fit\n",
32083208
"\n",
@@ -3265,7 +3265,7 @@
32653265
"cell_type": "markdown",
32663266
"metadata": {},
32673267
"source": [
3268-
"We initialize our sampling wrapper. Let's stop and analize each of the arguments. \n",
3268+
"We initialize our sampling wrapper. Let's stop and analyze each of the arguments. \n",
32693269
"\n",
32703270
"We then use the `log_lik_fun` and `posterior_vars` arguments to tell the wrapper how to call {func}`~xarray:xarray.apply_ufunc`. `log_lik_fun` is the function to be called, which is then called with the following positional arguments:\n",
32713271
"\n",
@@ -3419,5 +3419,5 @@
34193419
}
34203420
},
34213421
"nbformat": 4,
3422-
"nbformat_minor": 2
3422+
"nbformat_minor": 4
34233423
}

Diff for: doc/source/user_guide/pystan_refitting.ipynb

+154-65
Large diffs are not rendered by default.

Diff for: doc/source/user_guide/sampling_wrappers.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ whereas the second one externalizes the computation of the pointwise log
1111
likelihood to the user who is expected to write it with xarray/numpy.
1212

1313
```{toctree}
14+
pystan2_refitting
1415
pystan_refitting
1516
pymc3_refitting
1617
numpyro_refitting
17-
pystan_refitting_xr_lik
18+
pystan2_refitting_xr_lik
1819
pymc3_refitting_xr_lik
1920
numpyro_refitting_xr_lik
2021
```

0 commit comments

Comments
 (0)