Skip to content

Commit 580a32a

Browse files
ricardoV94twiecki
andauthored
Implement more robust jitter init (resolves #4107) (#4298)
* Implement robust jitter for init_nuts * Update Release notes * Small documentation formatting * Update pymc3/sampling.py Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 3fa3d1f commit 580a32a

File tree

3 files changed

+76
-14
lines changed

3 files changed

+76
-14
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
4848
- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721))
4949
- Refactored MvNormal.random method with better handling of sample, batch and event shapes. [#4207](https://github.com/pymc-devs/pymc3/pull/4207)
5050
- The `InverseGamma` distribution now implements a `logcdf`. [#3944](https://github.com/pymc-devs/pymc3/pull/3944)
51+
- Make the starting jitter methods for NUTS sampling more robust by resampling values that lead to non-finite probabilities. A new optional argument `jitter_max_retries` can be passed to `pm.sample()` and `pm.init_nuts()` to control the maximum number of retries per chain. [#4298](https://github.com/pymc-devs/pymc3/pull/4298)
5152

5253
### Documentation
5354
- Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)).

pymc3/sampling.py

+53-14
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
chains_and_samples,
6161
)
6262
from .vartypes import discrete_types
63-
from .exceptions import IncorrectArgumentsError
63+
from .exceptions import IncorrectArgumentsError, SamplingError
6464
from .parallel_sampling import _cpu_count, Draw
6565
from pymc3.step_methods.hmc import quadpotential
6666
import pymc3 as pm
@@ -246,6 +246,7 @@ def sample(
246246
discard_tuned_samples=True,
247247
compute_convergence_checks=True,
248248
callback=None,
249+
jitter_max_retries=10,
249250
*,
250251
return_inferencedata=None,
251252
idata_kwargs: dict = None,
@@ -331,6 +332,10 @@ def sample(
331332
the ``draw.chain`` argument can be used to determine which of the active chains the sample
332333
is drawn from.
333334
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
335+
jitter_max_retries : int
336+
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
337+
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
338+
init methods.
334339
return_inferencedata : bool, default=False
335340
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
336341
Defaults to `False`, but we'll switch to `True` in an upcoming release.
@@ -490,6 +495,7 @@ def sample(
490495
model=model,
491496
random_seed=random_seed,
492497
progressbar=progressbar,
498+
jitter_max_retries=jitter_max_retries,
493499
**kwargs,
494500
)
495501
if start is None:
@@ -1946,13 +1952,52 @@ def sample_prior_predictive(
19461952
return prior
19471953

19481954

1955+
def _init_jitter(model, chains, jitter_max_retries):
    """Build one jittered starting point per chain from the model's test point.

    Every value in ``model.test_point`` is copied and shifted by independent
    uniform noise drawn as ``2 * np.random.rand(...) - 1``.

    ``check_start_vals`` is used to validate each jittered point. A point that
    makes it raise ``SamplingError`` is discarded and redrawn, up to
    ``jitter_max_retries`` times per chain. Once the retries are exhausted, one
    final point is drawn and returned *without* being validated, so callers
    that need a guaranteed-valid start must check the result themselves.

    Parameters
    ----------
    model : pymc3.Model
    chains : int
        Number of starting points to generate (one per chain).
    jitter_max_retries : int
        Maximum number of repeated attempts at initializing values (per chain).
        Must be non-negative.

    Returns
    -------
    start : list of dict
        One jittered point per chain, each with the same keys as
        ``model.test_point``.

    Raises
    ------
    ValueError
        If ``jitter_max_retries`` is negative.
    """
    # A negative value would make the retry loop empty and leave no point to
    # append, so reject it up front with a clear message.
    if jitter_max_retries < 0:
        raise ValueError(
            "jitter_max_retries must be non-negative, got {}".format(jitter_max_retries)
        )

    # Read the test point once; each attempt below works on fresh copies of
    # its values, so the shared reference is never modified.
    test_point = model.test_point

    start = []
    for _ in range(chains):
        for attempt in range(jitter_max_retries + 1):
            point = {var: val.copy() for var, val in test_point.items()}
            for val in point.values():
                # In-place add onto the copy made just above.
                val[...] += 2 * np.random.rand(*val.shape) - 1

            if attempt == jitter_max_retries:
                # Retries exhausted: keep this last draw unchecked.
                break

            try:
                check_start_vals(point, model)
            except SamplingError:
                # Invalid starting point; draw again.
                continue
            break

        start.append(point)
    return start
1991+
1992+
19491993
def init_nuts(
19501994
init="auto",
19511995
chains=1,
19521996
n_init=500000,
19531997
model=None,
19541998
random_seed=None,
19551999
progressbar=True,
2000+
jitter_max_retries=10,
19562001
**kwargs,
19572002
):
19582003
"""Set up the mass matrix initialization for NUTS.
@@ -1967,7 +2012,7 @@ def init_nuts(
19672012
Initialization method to use.
19682013
19692014
* auto: Choose a default initialization method automatically.
1970-
Currently, this is `'jitter+adapt_diag'`, but this can change in the future. If you
2015+
Currently, this is ``jitter+adapt_diag``, but this can change in the future. If you
19712016
depend on the exact behaviour, choose an initialization method explicitly.
19722017
* adapt_diag: Start with a identity mass matrix and then adapt a diagonal based on the
19732018
variance of the tuning samples. All chains use the test value (usually the prior mean)
@@ -1994,6 +2039,10 @@ def init_nuts(
19942039
model : Model (optional if in ``with`` context)
19952040
progressbar : bool
19962041
Whether or not to display a progressbar for advi sampling.
2042+
jitter_max_retries : int
2043+
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
2044+
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
2045+
init methods.
19972046
**kwargs : keyword arguments
19982047
Extra keyword arguments are forwarded to pymc3.NUTS.
19992048
@@ -2038,12 +2087,7 @@ def init_nuts(
20382087
var = np.ones_like(mean)
20392088
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
20402089
elif init == "jitter+adapt_diag":
2041-
start = []
2042-
for _ in range(chains):
2043-
mean = {var: val.copy() for var, val in model.test_point.items()}
2044-
for val in mean.values():
2045-
val[...] += 2 * np.random.rand(*val.shape) - 1
2046-
start.append(mean)
2090+
start = _init_jitter(model, chains, jitter_max_retries)
20472091
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
20482092
var = np.ones_like(mean)
20492093
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
@@ -2125,12 +2169,7 @@ def init_nuts(
21252169
cov = np.eye(model.ndim)
21262170
potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10)
21272171
elif init == "jitter+adapt_full":
2128-
start = []
2129-
for _ in range(chains):
2130-
mean = {var: val.copy() for var, val in model.test_point.items()}
2131-
for val in mean.values():
2132-
val[...] += 2 * np.random.rand(*val.shape) - 1
2133-
start.append(mean)
2172+
start = _init_jitter(model, chains, jitter_max_retries)
21342173
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
21352174
cov = np.eye(model.ndim)
21362175
potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10)

pymc3/tests/test_sampling.py

+22
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,28 @@ def _mocked_init_nuts(*args, **kwargs):
815815
pm.sample(tune=1, draws=0, chains=1, init=init, start=start)
816816

817817

818+
@pytest.mark.parametrize(
    "testval, jitter_max_retries, expectation",
    [
        (0, 0, pytest.raises(SamplingError)),
        (0, 1, pytest.raises(SamplingError)),
        (0, 4, does_not_raise()),
        (0, 10, does_not_raise()),
        (1, 0, does_not_raise()),
    ],
)
def test_init_jitter(testval, jitter_max_retries, expectation):
    """Check that ``_init_jitter`` retries invalid jittered starting values.

    ``numpy.random.rand`` is patched to return a fixed sequence of draws. A
    draw of 0 yields a jitter of -1 and a draw of 1 yields a jitter of +1, so
    with ``testval=0`` the fourth attempt is the first one that produces a
    positive (valid) starting value for the untransformed ``HalfNormal``.
    """
    # Starting value is negative (invalid) when np.random.rand returns 0 (jitter = -1)
    # and positive (valid) when it returns 1 (jitter = 1)
    mocked_draws = [0, 0, 0, 1, 0]

    with pm.Model() as model:
        pm.HalfNormal("x", transform=None, testval=testval)

    with expectation:
        with mock.patch("numpy.random.rand", side_effect=mocked_draws):
            start = pm.sampling._init_jitter(
                model,
                chains=1,
                jitter_max_retries=jitter_max_retries,
            )
        # Validation happens outside the patch: only the jitter draws are mocked.
        pm.util.check_start_vals(start, model)
838+
839+
818840
@pytest.fixture(scope="class")
819841
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
820842
with pm.Model() as pmodel:

0 commit comments

Comments
 (0)