Skip to content

Commit 22c079c

Browse files
StephenHogg (Stephen Hogg) and twiecki
authored
Test model logp before starting any MCMC chains (#4211)
* Re-create branch
* Fix merge conflict
* bug fix
* remove unneeded import
* Update pymc3/util.py as per twiecki

  Co-authored-by: Thomas Wiecki <[email protected]>
* fix test_examples.py
* fix test_step.py
* remove unnecessary import

Co-authored-by: Stephen Hogg <[email protected]>
Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 27b1a7d commit 22c079c

File tree

9 files changed

+102
-34
lines changed

9 files changed

+102
-34
lines changed

RELEASE-NOTES.md

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
4242
- Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
4343
- Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution. [#4129](https://github.com/pymc-devs/pymc3/pull/4129)
4444
- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)
45+
- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116))
46+
- Fixed a bug in `model.check_test_point` that caused the `test_point` argument to be ignored (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721)).
4547

4648
### Documentation
4749
- Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)).

pymc3/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,7 @@ def check_test_point(self, test_point=None, round_vals=2):
13681368
test_point = self.test_point
13691369

13701370
return Series(
1371-
{RV.name: np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs},
1371+
{RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs},
13721372
name="Log-probability of test_point",
13731373
)
13741374

pymc3/sampling.py

+11
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
PGBART,
5555
)
5656
from .util import (
57+
check_start_vals,
5758
update_start_vals,
5859
get_untransformed_name,
5960
is_transformed_name,
@@ -419,7 +420,16 @@ def sample(
419420
420421
"""
421422
model = modelcontext(model)
423+
if start is None:
424+
start = model.test_point
425+
else:
426+
if isinstance(start, dict):
427+
update_start_vals(start, model.test_point, model)
428+
else:
429+
for chain_start_vals in start:
430+
update_start_vals(chain_start_vals, model.test_point, model)
422431

432+
check_start_vals(start, model)
423433
if cores is None:
424434
cores = min(4, _cpu_count())
425435

@@ -487,6 +497,7 @@ def sample(
487497
progressbar=progressbar,
488498
**kwargs,
489499
)
500+
check_start_vals(start_, model)
490501
if start is None:
491502
start = start_
492503
except (AttributeError, NotImplementedError, tg.NullTypeGradError):

pymc3/tests/test_examples.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def build_model(self):
274274
# Estimated mean count
275275
theta = pm.Uniform("theta", 0, 100)
276276
# Poisson likelihood
277-
pm.ZeroInflatedPoisson("y", theta, psi, observed=self.y)
277+
pm.ZeroInflatedPoisson("y", psi, theta, observed=self.y)
278278
return model
279279

280280
def test_run(self):

pymc3/tests/test_hmc.py

-15
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717

1818
from . import models
1919
from pymc3.step_methods.hmc.base_hmc import BaseHMC
20-
from pymc3.exceptions import SamplingError
2120
import pymc3
22-
import pytest
2321
import logging
2422
from pymc3.theanof import floatX
2523

@@ -57,16 +55,3 @@ def test_nuts_tuning():
5755

5856
assert not step.tune
5957
assert np.all(trace["step_size"][5:] == trace["step_size"][5])
60-
61-
62-
def test_nuts_error_reporting(caplog):
63-
model = pymc3.Model()
64-
with caplog.at_level(logging.CRITICAL) and pytest.raises(SamplingError):
65-
with model:
66-
pymc3.HalfNormal("a", sigma=1, transform=None, testval=-1)
67-
pymc3.HalfNormal("b", sigma=1, transform=None)
68-
trace = pymc3.sample(init="adapt_diag", chains=1)
69-
assert (
70-
"Bad initial energy, check any log probabilities that are inf or -inf: a -inf\nb"
71-
in caplog.text
72-
)

pymc3/tests/test_step.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
simple_2model_continuous,
2828
)
2929
from pymc3.sampling import assign_step_methods, sample
30-
from pymc3.parallel_sampling import ParallelSamplingError
3130
from pymc3.exceptions import SamplingError
3231
from pymc3.model import Model, Potential, set_data
3332

@@ -963,15 +962,15 @@ def test_bad_init_nonparallel(self):
963962
HalfNormal("a", sigma=1, testval=-1, transform=None)
964963
with pytest.raises(SamplingError) as error:
965964
sample(init=None, chains=1, random_seed=1)
966-
error.match("Bad initial")
965+
error.match("Initial evaluation")
967966

968967
@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher")
969968
def test_bad_init_parallel(self):
970969
with Model():
971970
HalfNormal("a", sigma=1, testval=-1, transform=None)
972-
with pytest.raises(ParallelSamplingError) as error:
971+
with pytest.raises(SamplingError) as error:
973972
sample(init=None, cores=2, random_seed=1)
974-
error.match("Bad initial")
973+
error.match("Initial evaluation")
975974

976975
def test_linalg(self, caplog):
977976
with Model():

pymc3/tests/test_util.py

+34
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,40 @@ def test_soft_update_parent(self):
9595
assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"])
9696

9797

98+
class TestCheckStartVals(SeededTest):
99+
def setup_method(self):
100+
super().setup_method()
101+
102+
def test_valid_start_point(self):
103+
with pm.Model() as model:
104+
a = pm.Uniform("a", lower=0.0, upper=1.0)
105+
b = pm.Uniform("b", lower=2.0, upper=3.0)
106+
107+
start = {"a": 0.3, "b": 2.1}
108+
pm.util.update_start_vals(start, model.test_point, model)
109+
pm.util.check_start_vals(start, model)
110+
111+
def test_invalid_start_point(self):
112+
with pm.Model() as model:
113+
a = pm.Uniform("a", lower=0.0, upper=1.0)
114+
b = pm.Uniform("b", lower=2.0, upper=3.0)
115+
116+
start = {"a": np.nan, "b": np.nan}
117+
pm.util.update_start_vals(start, model.test_point, model)
118+
with pytest.raises(pm.exceptions.SamplingError):
119+
pm.util.check_start_vals(start, model)
120+
121+
def test_invalid_variable_name(self):
122+
with pm.Model() as model:
123+
a = pm.Uniform("a", lower=0.0, upper=1.0)
124+
b = pm.Uniform("b", lower=2.0, upper=3.0)
125+
126+
start = {"a": 0.3, "b": 2.1, "c": 1.0}
127+
pm.util.update_start_vals(start, model.test_point, model)
128+
with pytest.raises(KeyError):
129+
pm.util.check_start_vals(start, model)
130+
131+
98132
class TestExceptions:
99133
def test_shape_error(self):
100134
with pytest.raises(pm.exceptions.ShapeError) as exinfo:

pymc3/tuning/starting.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..theanof import inputvars
2929
import theano.gradient as tg
3030
from ..blocking import DictToArrayBijection, ArrayOrdering
31-
from ..util import update_start_vals, get_default_varnames, get_var_name
31+
from ..util import check_start_vals, update_start_vals, get_default_varnames, get_var_name
3232

3333
import warnings
3434
from inspect import getargspec
@@ -89,13 +89,7 @@ def find_MAP(
8989
else:
9090
update_start_vals(start, model.test_point, model)
9191

92-
if not set(start.keys()).issubset(model.named_vars.keys()):
93-
extra_keys = ", ".join(set(start.keys()) - set(model.named_vars.keys()))
94-
valid_keys = ", ".join(model.named_vars.keys())
95-
raise KeyError(
96-
"Some start parameters do not appear in the model!\n"
97-
"Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys)
98-
)
92+
check_start_vals(start, model)
9993

10094
if vars is None:
10195
vars = model.cont_vars

pymc3/util.py

+48-5
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import functools
1717
from typing import List, Dict, Tuple, Union
1818

19+
import numpy as np
1920
import xarray
2021
import arviz
21-
from numpy import ndarray
2222

23+
from pymc3.exceptions import SamplingError
2324
from theano.tensor import TensorVariable
2425

2526

@@ -188,6 +189,48 @@ def update_start_vals(a, b, model):
188189
a.update({k: v for k, v in b.items() if k not in a})
189190

190191

192+
def check_start_vals(start, model):
193+
r"""Check that the starting values for MCMC do not cause the relevant log probability
194+
to evaluate to something invalid (e.g. Inf or NaN)
195+
196+
Parameters
197+
----------
198+
start : dict, or array of dict
199+
Starting point in parameter space (or partial point)
200+
Defaults to ``trace.point(-1)`` if there is a trace provided and model.test_point if not
201+
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
202+
overwrite the default.
203+
model : Model object
204+
Raises
205+
------
206+
KeyError if the parameters provided by `start` do not agree with the parameters contained
207+
within `model`
208+
pymc3.exceptions.SamplingError if the evaluation of the parameters in `start` leads to an
209+
invalid (i.e. non-finite) state
210+
Returns
211+
-------
212+
None
213+
"""
214+
start_points = [start] if isinstance(start, dict) else start
215+
for elem in start_points:
216+
if not set(elem.keys()).issubset(model.named_vars.keys()):
217+
extra_keys = ", ".join(set(elem.keys()) - set(model.named_vars.keys()))
218+
valid_keys = ", ".join(model.named_vars.keys())
219+
raise KeyError(
220+
"Some start parameters do not appear in the model!\n"
221+
"Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys)
222+
)
223+
224+
initial_eval = model.check_test_point(test_point=elem)
225+
226+
if not np.all(np.isfinite(initial_eval)):
227+
raise SamplingError(
228+
"Initial evaluation of model at starting point failed!\n"
229+
"Starting values:\n{}\n\n"
230+
"Initial evaluation results:\n{}".format(elem, str(initial_eval))
231+
)
232+
233+
191234
def get_transformed(z):
192235
if hasattr(z, "transformed"):
193236
z = z.transformed
@@ -214,13 +257,13 @@ def enhanced(*args, **kwargs):
214257

215258
# FIXME: this function is poorly named, because it returns a LIST of
216259
# points, not a dictionary of points.
217-
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
260+
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
218261
# grab posterior samples for each variable
219-
_samples: Dict[str, ndarray] = {vn: ds[vn].values for vn in ds.keys()}
262+
_samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()}
220263
# make dicts
221-
points: List[Dict[str, ndarray]] = []
264+
points: List[Dict[str, np.ndarray]] = []
222265
vn: str
223-
s: ndarray
266+
s: np.ndarray
224267
for c in ds.chain:
225268
for d in ds.draw:
226269
points.append({vn: s[c, d] for vn, s in _samples.items()})

0 commit comments

Comments
 (0)