
Commit 4f8ad5d

Improve tuning by skipping the first samples + add new experimental tuning method (#5004)
* Fix issue in hmc gradient storage
* Skip first samples during NUTS adaptation
* Add test and doc for jitter+adapt_diag_grad
* Improve tests of init methods
* Add new tuning method to release notes
* Remove old gradient mass matrix adaptation
* Remove weight argument in quadpotential add_sample
1 parent bcc40ce commit 4f8ad5d

File tree

6 files changed: +200 -104 lines


RELEASE-NOTES.md (+2)

@@ -23,6 +23,8 @@
 - The `OrderedMultinomial` distribution has been added for use on ordinal data which are _aggregated_ by trial, like multinomial observations, whereas `OrderedLogistic` only accepts ordinal data in a _disaggregated_ format, like categorical
   observations (see [#4773](https://github.com/pymc-devs/pymc3/pull/4773)).
 - The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc3/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment.
+- A small change to the mass matrix tuning methods jitter+adapt_diag (the default) and adapt_diag improves performance early on during tuning for some models. [#5004](https://github.com/pymc-devs/pymc3/pull/5004)
+- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc3/pull/5004)
 - ...

 ### Maintenance
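The `jitter+adapt_diag` family of initializers mentioned in the release notes starts each chain from the model's initial point plus uniform jitter in [-1, 1]. A minimal sketch of that jitter step (the helper name `jitter_start` is illustrative; pymc3 implements this as `_init_jitter`, which additionally retries starting points with non-finite logp up to `jitter_max_retries` times):

```python
import numpy as np

def jitter_start(initial_point, chains, seed=None):
    """One starting point per chain: the model's initial point
    plus uniform jitter drawn from [-1, 1] in each dimension."""
    rng = np.random.default_rng(seed)
    return [
        initial_point + rng.uniform(-1, 1, size=initial_point.shape)
        for _ in range(chains)
    ]

starts = jitter_start(np.zeros(3), chains=4, seed=123)
```

Starting the chains at distinct points helps convergence diagnostics such as R-hat detect multimodality that identical starting points would hide.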

pymc3/sampling.py (+23 -39)

@@ -287,25 +287,7 @@ def sample(
         by default. See ``discard_tuned_samples``.
     init : str
         Initialization method to use for auto-assigned NUTS samplers.
-
-        * auto: Choose a default initialization method automatically.
-          Currently, this is ``jitter+adapt_diag``, but this can change in the future.
-          If you depend on the exact behaviour, choose an initialization method explicitly.
-        * adapt_diag: Start with a identity mass matrix and then adapt a diagonal based on the
-          variance of the tuning samples. All chains use the test value (usually the prior mean)
-          as starting point.
-        * jitter+adapt_diag: Same as ``adapt_diag``, but add uniform jitter in [-1, 1] to the
-          starting point in each chain.
-        * advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
-          sample variance of the tuning samples.
-        * advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
-          on the variance of the gradients during tuning. This is **experimental** and might be
-          removed in a future release.
-        * advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
-        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
-        * map: Use the MAP as starting point. This is discouraged.
-        * adapt_full: Adapt a dense mass matrix using the sample covariances
-
+        See `pm.init_nuts` for a list of all options.
     step : function or iterable of functions
         A step function or collection of functions. If there are variables without step methods,
         step methods for those variables will be assigned automatically. By default the NUTS step

@@ -516,6 +498,7 @@ def sample(
        random_seed=random_seed,
        progressbar=progressbar,
        jitter_max_retries=jitter_max_retries,
+       tune=tune,
        **kwargs,
    )
    if start is None:

@@ -2078,6 +2061,7 @@ def init_nuts(
    random_seed=None,
    progressbar=True,
    jitter_max_retries=10,
+   tune=None,
    **kwargs,
 ):
    """Set up the mass matrix initialization for NUTS.

@@ -2099,11 +2083,11 @@
       as starting point.
     * jitter+adapt_diag: Same as ``adapt_diag``, but use test value plus a uniform jitter in
       [-1, 1] as starting point in each chain.
+    * jitter+adapt_diag_grad:
+      An experimental initialization method that uses information from gradients and samples
+      during tuning.
     * advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
       sample variance of the tuning samples.
-    * advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
-      on the variance of the gradients during tuning. This is **experimental** and might be
-      removed in a future release.
     * advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
     * advi_map: Initialize ADVI with MAP and use MAP as starting point.
     * map: Use the MAP as starting point. This is discouraged.

@@ -2174,24 +2158,24 @@ def init_nuts(
        var = np.ones_like(mean)
        n = len(var)
        potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
-   elif init == "advi+adapt_diag_grad":
-       approx: pm.MeanField = pm.fit(
-           random_seed=random_seed,
-           n=n_init,
-           method="advi",
-           model=model,
-           callbacks=cb,
-           progressbar=progressbar,
-           obj_optimizer=pm.adagrad_window,
+   elif init == "jitter+adapt_diag_grad":
+       start = _init_jitter(model, model.initial_point, chains, jitter_max_retries)
+       mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0)
+       var = np.ones_like(mean)
+       n = len(var)
+
+       if tune is not None and tune > 250:
+           stop_adaptation = tune - 50
+       else:
+           stop_adaptation = None
+
+       potential = quadpotential.QuadPotentialDiagAdaptExp(
+           n,
+           mean,
+           alpha=0.02,
+           use_grads=True,
+           stop_adaptation=stop_adaptation,
        )
-       start = approx.sample(draws=chains)
-       start = list(start)
-       std_apoint = approx.std.eval()
-       cov = std_apoint ** 2
-       mean = approx.mean.get_value()
-       weight = 50
-       n = len(cov)
-       potential = quadpotential.QuadPotentialDiagAdaptGrad(n, mean, cov, weight)
    elif init == "advi+adapt_diag":
        approx = pm.fit(
            random_seed=random_seed,
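In the new `jitter+adapt_diag_grad` branch, `QuadPotentialDiagAdaptExp` maintains an exponentially weighted estimate (`alpha=0.02`) of the variance of the tuning samples and, with `use_grads=True`, of their gradients, while `stop_adaptation` freezes the mass matrix 50 steps before tuning ends whenever `tune > 250`. A self-contained sketch of both pieces (the class and function names here are illustrative, not the pymc3 API):

```python
import numpy as np

def stop_adaptation_step(tune):
    """Freeze mass-matrix adaptation 50 steps before tuning ends,
    but only when there are more than 250 tuning steps; otherwise
    adapt for the whole tuning phase (None)."""
    if tune is not None and tune > 250:
        return tune - 50
    return None

class ExpWeightedVar:
    """Exponentially weighted mean/variance estimator, sketching how
    an adapt_diag-style diagonal mass matrix can discount early,
    poorly mixed tuning samples via a smoothing factor alpha."""

    def __init__(self, n, alpha=0.02):
        self.mean = np.zeros(n)
        self.var = np.ones(n)
        self.alpha = alpha

    def update(self, sample):
        # Welford-style exponentially weighted update: recent samples
        # dominate, so the earliest (worst) draws fade out quickly.
        delta = sample - self.mean
        self.mean = self.mean + self.alpha * delta
        self.var = (1.0 - self.alpha) * (self.var + self.alpha * delta ** 2)
```

Because the weights decay geometrically, the first tuning samples (drawn before step-size adaptation has settled) contribute almost nothing to the final variance estimate, which is the "skipping the first samples" effect named in the commit title.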

pymc3/step_methods/hmc/nuts.py (+2 -4)

@@ -253,9 +253,7 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
        self.start_energy = np.array(start.energy)

        self.left = self.right = start
-       self.proposal = Proposal(
-           start.q.data, start.q_grad.data, start.energy, 1.0, start.model_logp
-       )
+       self.proposal = Proposal(start.q.data, start.q_grad, start.energy, 1.0, start.model_logp)
        self.depth = 0
        self.log_size = 0
        self.log_weighted_accept_sum = -np.inf

@@ -350,7 +348,7 @@ def _single_step(self, left, epsilon):
        log_size = -energy_change
        proposal = Proposal(
            right.q.data,
-           right.q_grad.data,
+           right.q_grad,
            right.energy,
            log_p_accept_weighted,
            right.model_logp,
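Both nuts.py hunks drop `.data` from `q_grad` when building the `Proposal`. A plausible reading of the "hmc gradient storage" fix (our inference, not stated in the commit): `start.q` is a wrapper whose `.data` field holds the raveled array (the sampling.py diff likewise uses `DictToArrayBijection.map(vals).data`), whereas `q_grad` is already a plain NumPy array, and on an ndarray the `.data` attribute is a raw buffer `memoryview` rather than the array itself:

```python
import numpy as np

q_grad = np.array([0.1, -0.2, 0.3])

# On a plain ndarray, .data is a memoryview over the underlying
# buffer, not another ndarray:
buf = q_grad.data
assert isinstance(buf, memoryview)
assert not isinstance(buf, np.ndarray)

# A memoryview lacks ndarray semantics (ufuncs, broadcasting, ...),
# so storing it where downstream code expects an array is a latent
# bug; storing q_grad directly keeps the gradient usable.
assert isinstance(q_grad, np.ndarray)
```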

0 commit comments