Test model logp before starting any MCMC chains #4211

Merged: 10 commits, Nov 27, 2020
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
@@ -42,6 +42,8 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
- Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
- Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution. [#4129](https://github.com/pymc-devs/pymc3/pull/4129)
- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)
- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116))
- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721))

### Documentation
- Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)).
2 changes: 1 addition & 1 deletion pymc3/model.py
@@ -1368,7 +1368,7 @@ def check_test_point(self, test_point=None, round_vals=2):
test_point = self.test_point

return Series(
{RV.name: np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs},
{RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs},
name="Log-probability of test_point",
)
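
For context, a minimal sketch of the fixed behavior, assuming the pymc3 3.x API: `check_test_point` now evaluates the point you pass in, instead of silently falling back to `model.test_point`. The model and values below are illustrative only.

```python
import pymc3 as pm

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)

# Before this fix the `test_point` argument was ignored and
# `model.test_point` was evaluated instead; now the supplied
# point is what gets scored.
print(model.check_test_point(test_point={"x": 2.0}))
```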

11 changes: 11 additions & 0 deletions pymc3/sampling.py
@@ -54,6 +54,7 @@
PGBART,
)
from .util import (
check_start_vals,
update_start_vals,
get_untransformed_name,
is_transformed_name,
@@ -419,7 +420,16 @@ def sample(

"""
model = modelcontext(model)
if start is None:
start = model.test_point
else:
if isinstance(start, dict):
update_start_vals(start, model.test_point, model)
else:
for chain_start_vals in start:
update_start_vals(chain_start_vals, model.test_point, model)

check_start_vals(start, model)
if cores is None:
cores = min(4, _cpu_count())

@@ -487,6 +497,7 @@ def sample(
progressbar=progressbar,
**kwargs,
)
check_start_vals(start_, model)
if start is None:
start = start_
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
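
Taken together, these `pm.sample` changes fail fast: user-supplied start values are merged with `model.test_point` and validated with `check_start_vals` before any chain launches. A hedged sketch of the user-facing effect, mirroring the updated tests in `pymc3/tests/test_step.py`:

```python
import pymc3 as pm
from pymc3.exceptions import SamplingError

with pm.Model():
    # testval=-1 lies outside the support of HalfNormal, so the
    # model's initial log probability is -inf.
    pm.HalfNormal("a", sigma=1, testval=-1, transform=None)
    try:
        pm.sample(init=None, chains=1, random_seed=1)
    except SamplingError as err:
        # Raised before any sampling happens; the message starts with
        # "Initial evaluation of model at starting point failed!"
        print(err)
```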
2 changes: 1 addition & 1 deletion pymc3/tests/test_examples.py
@@ -274,7 +274,7 @@ def build_model(self):
# Estimated mean count
theta = pm.Uniform("theta", 0, 100)
# Poisson likelihood
pm.ZeroInflatedPoisson("y", theta, psi, observed=self.y)
pm.ZeroInflatedPoisson("y", psi, theta, observed=self.y)
return model

def test_run(self):
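
The one-line test fix above corrects the argument order: in pymc3, `ZeroInflatedPoisson` takes the zero-inflation parameter `psi` before the rate `theta`, as the swap in the diff itself confirms. A short sketch of the corrected call, with toy observed data for illustration:

```python
import pymc3 as pm

with pm.Model():
    psi = pm.Beta("psi", 1.0, 1.0)       # expected proportion of Poisson variates
    theta = pm.Uniform("theta", 0, 100)  # estimated mean count
    # Positional order is (psi, theta); the old test had them reversed.
    pm.ZeroInflatedPoisson("y", psi, theta, observed=[0, 0, 1, 3])
```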
15 changes: 0 additions & 15 deletions pymc3/tests/test_hmc.py
@@ -17,9 +17,7 @@

from . import models
from pymc3.step_methods.hmc.base_hmc import BaseHMC
from pymc3.exceptions import SamplingError
import pymc3
import pytest
import logging
from pymc3.theanof import floatX

@@ -57,16 +55,3 @@ def test_nuts_tuning():

assert not step.tune
assert np.all(trace["step_size"][5:] == trace["step_size"][5])


def test_nuts_error_reporting(caplog):
model = pymc3.Model()
with caplog.at_level(logging.CRITICAL) and pytest.raises(SamplingError):
with model:
pymc3.HalfNormal("a", sigma=1, transform=None, testval=-1)
pymc3.HalfNormal("b", sigma=1, transform=None)
trace = pymc3.sample(init="adapt_diag", chains=1)
assert (
"Bad initial energy, check any log probabilities that are inf or -inf: a -inf\nb"
in caplog.text
)
7 changes: 3 additions & 4 deletions pymc3/tests/test_step.py
@@ -27,7 +27,6 @@
simple_2model_continuous,
)
from pymc3.sampling import assign_step_methods, sample
from pymc3.parallel_sampling import ParallelSamplingError
from pymc3.exceptions import SamplingError
from pymc3.model import Model, Potential, set_data

@@ -963,15 +962,15 @@ def test_bad_init_nonparallel(self):
HalfNormal("a", sigma=1, testval=-1, transform=None)
with pytest.raises(SamplingError) as error:
sample(init=None, chains=1, random_seed=1)
error.match("Bad initial")
error.match("Initial evaluation")

@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher")
def test_bad_init_parallel(self):
with Model():
HalfNormal("a", sigma=1, testval=-1, transform=None)
with pytest.raises(ParallelSamplingError) as error:
with pytest.raises(SamplingError) as error:
sample(init=None, cores=2, random_seed=1)
error.match("Bad initial")
error.match("Initial evaluation")

def test_linalg(self, caplog):
with Model():
34 changes: 34 additions & 0 deletions pymc3/tests/test_util.py
@@ -95,6 +95,40 @@ def test_soft_update_parent(self):
assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"])


class TestCheckStartVals(SeededTest):
def setup_method(self):
super().setup_method()

def test_valid_start_point(self):
with pm.Model() as model:
a = pm.Uniform("a", lower=0.0, upper=1.0)
b = pm.Uniform("b", lower=2.0, upper=3.0)

start = {"a": 0.3, "b": 2.1}
pm.util.update_start_vals(start, model.test_point, model)
pm.util.check_start_vals(start, model)

def test_invalid_start_point(self):
with pm.Model() as model:
a = pm.Uniform("a", lower=0.0, upper=1.0)
b = pm.Uniform("b", lower=2.0, upper=3.0)

start = {"a": np.nan, "b": np.nan}
pm.util.update_start_vals(start, model.test_point, model)
with pytest.raises(pm.exceptions.SamplingError):
pm.util.check_start_vals(start, model)

def test_invalid_variable_name(self):
with pm.Model() as model:
a = pm.Uniform("a", lower=0.0, upper=1.0)
b = pm.Uniform("b", lower=2.0, upper=3.0)

start = {"a": 0.3, "b": 2.1, "c": 1.0}
pm.util.update_start_vals(start, model.test_point, model)
with pytest.raises(KeyError):
pm.util.check_start_vals(start, model)


class TestExceptions:
def test_shape_error(self):
with pytest.raises(pm.exceptions.ShapeError) as exinfo:
10 changes: 2 additions & 8 deletions pymc3/tuning/starting.py
@@ -28,7 +28,7 @@
from ..theanof import inputvars
import theano.gradient as tg
from ..blocking import DictToArrayBijection, ArrayOrdering
from ..util import update_start_vals, get_default_varnames, get_var_name
from ..util import check_start_vals, update_start_vals, get_default_varnames, get_var_name

import warnings
from inspect import getargspec
@@ -89,13 +89,7 @@ def find_MAP(
else:
update_start_vals(start, model.test_point, model)

if not set(start.keys()).issubset(model.named_vars.keys()):
extra_keys = ", ".join(set(start.keys()) - set(model.named_vars.keys()))
valid_keys = ", ".join(model.named_vars.keys())
raise KeyError(
"Some start parameters do not appear in the model!\n"
"Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys)
)
check_start_vals(start, model)

if vars is None:
vars = model.cont_vars
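
With this change, `find_MAP` delegates start-point validation to the shared `check_start_vals` helper, so it raises the same errors as `pm.sample`. A sketch of the preserved `KeyError` behavior (the variable name "b" is deliberately bogus):

```python
import pymc3 as pm

with pm.Model():
    pm.HalfNormal("a", sigma=1, transform=None)
    try:
        # "b" is not a variable in the model, so check_start_vals
        # raises a KeyError listing the valid keys.
        pm.find_MAP(start={"b": 1.0})
    except KeyError as err:
        print(err)
```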
53 changes: 48 additions & 5 deletions pymc3/util.py
@@ -16,10 +16,11 @@
import functools
from typing import List, Dict, Tuple, Union

import numpy as np
import xarray
import arviz
from numpy import ndarray

from pymc3.exceptions import SamplingError
from theano.tensor import TensorVariable


@@ -188,6 +189,48 @@ def update_start_vals(a, b, model):
a.update({k: v for k, v in b.items() if k not in a})


def check_start_vals(start, model):
r"""Check that the starting values for MCMC do not cause the relevant log probability
to evaluate to something invalid (e.g. Inf or NaN)

Parameters
----------
start : dict, or array of dict
Starting point in parameter space (or partial point)
Defaults to ``trace.point(-1))`` if there is a trace provided and model.test_point if not
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
overwrite the default.
model : Model object
Raises
______
KeyError if the parameters provided by `start` do not agree with the parameters contained
within `model`
pymc3.exceptions.SamplingError if the evaluation of the parameters in `start` leads to an
invalid (i.e. non-finite) state
Returns
-------
None
"""
start_points = [start] if isinstance(start, dict) else start
for elem in start_points:
if not set(elem.keys()).issubset(model.named_vars.keys()):
extra_keys = ", ".join(set(elem.keys()) - set(model.named_vars.keys()))
valid_keys = ", ".join(model.named_vars.keys())
raise KeyError(
"Some start parameters do not appear in the model!\n"
"Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys)
)

initial_eval = model.check_test_point(test_point=elem)

if not np.all(np.isfinite(initial_eval)):
raise SamplingError(
"Initial evaluation of model at starting point failed!\n"
"Starting values:\n{}\n\n"
"Initial evaluation results:\n{}".format(elem, str(initial_eval))
)


def get_transformed(z):
if hasattr(z, "transformed"):
z = z.transformed
@@ -214,13 +257,13 @@ def enhanced(*args, **kwargs):

# FIXME: this function is poorly named, because it returns a LIST of
# points, not a dictionary of points.
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
# grab posterior samples for each variable
_samples: Dict[str, ndarray] = {vn: ds[vn].values for vn in ds.keys()}
_samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()}
# make dicts
points: List[Dict[str, ndarray]] = []
points: List[Dict[str, np.ndarray]] = []
vn: str
s: ndarray
s: np.ndarray
for c in ds.chain:
for d in ds.draw:
points.append({vn: s[c, d] for vn, s in _samples.items()})
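
Finally, `check_start_vals` can be called directly, as the new tests in `pymc3/tests/test_util.py` do. A minimal sketch assuming the API added in this PR; the start values are illustrative:

```python
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    a = pm.Uniform("a", lower=0.0, upper=1.0)

# A valid start point passes silently.
start = {"a": 0.3}
pm.util.update_start_vals(start, model.test_point, model)
pm.util.check_start_vals(start, model)

# A non-finite start point raises SamplingError.
bad = {"a": np.nan}
pm.util.update_start_vals(bad, model.test_point, model)
try:
    pm.util.check_start_vals(bad, model)
except pm.exceptions.SamplingError as err:
    print(err)
```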