diff --git a/pymc3/step_methods/hmc/quadpotential.py b/pymc3/step_methods/hmc/quadpotential.py index 65ae08d969..886dd5f98b 100644 --- a/pymc3/step_methods/hmc/quadpotential.py +++ b/pymc3/step_methods/hmc/quadpotential.py @@ -126,8 +126,16 @@ def isquadpotential(value): class QuadPotentialDiagAdapt(QuadPotential): """Adapt a diagonal mass matrix from the sample variances.""" - def __init__(self, n, initial_mean, initial_diag=None, initial_weight=0, - adaptation_window=101, dtype=None): + def __init__( + self, + n, + initial_mean, + initial_diag=None, + initial_weight=0, + adaptation_window=101, + adaptation_window_multiplier=1, + dtype=None, + ): """Set up a diagonal mass matrix.""" if initial_diag is not None and initial_diag.ndim != 1: raise ValueError('Initial diagonal must be one-dimensional.') @@ -158,6 +166,7 @@ def __init__(self, n, initial_mean, initial_diag=None, initial_weight=0, self._background_var = _WeightedVariance(self._n, dtype=self.dtype) self._n_samples = 0 self.adaptation_window = adaptation_window + self.adaptation_window_multiplier = float(adaptation_window_multiplier) def velocity(self, x, out=None): """Compute the current velocity at a position in parameter space.""" @@ -190,15 +199,14 @@ def update(self, sample, grad, tune): if not tune: return - window = self.adaptation_window - self._foreground_var.add_sample(sample, weight=1) self._background_var.add_sample(sample, weight=1) self._update_from_weightvar(self._foreground_var) - if self._n_samples > 0 and self._n_samples % window == 0: + if self._n_samples > 0 and self._n_samples % self.adaptation_window == 0: self._foreground_var = self._background_var self._background_var = _WeightedVariance(self._n, dtype=self.dtype) + self.adaptation_window = int(self.adaptation_window * self.adaptation_window_multiplier) self._n_samples += 1 @@ -458,13 +466,7 @@ def velocity_energy(self, x, v_out): class QuadPotentialFullAdapt(QuadPotentialFull): - """Adapt a dense mass matrix using the sample covariances - - If the parameter ``doubling`` is true, the adaptation window is doubled - every time it is passed. This can lead to better convergence of the mass - matrix estimation. - - """ + """Adapt a dense mass matrix using the sample covariances.""" def __init__( self, n, @@ -472,8 +474,8 @@ def __init__( initial_cov=None, initial_weight=0, adaptation_window=101, + adaptation_window_multiplier=2, update_window=1, - doubling=True, dtype=None, ): warnings.warn("QuadPotentialFullAdapt is an experimental feature") @@ -511,8 +513,8 @@ def __init__( self._background_cov = _WeightedCovariance(self._n, dtype=self.dtype) self._n_samples = 0 - self._doubling = doubling self._adaptation_window = int(adaptation_window) + self._adaptation_window_multiplier = float(adaptation_window_multiplier) self._update_window = int(update_window) self._previous_update = 0 @@ -538,7 +540,8 @@ def update(self, sample, grad, tune): if (delta + 1) % self._update_window == 0: self._update_from_weightvar(self._foreground_cov) - # Reset the background covariance if we are at the end of the adaptation window. + # Reset the background covariance if we are at the end of the adaptation + # window. if delta >= self._adaptation_window: self._foreground_cov = self._background_cov self._background_cov = _WeightedCovariance( @@ -546,8 +549,7 @@ def update(self, sample, grad, tune): ) self._previous_update = self._n_samples - if self._doubling: - self._adaptation_window *= 2 + self._adaptation_window = int(self._adaptation_window * self._adaptation_window_multiplier) self._n_samples += 1 diff --git a/pymc3/tests/test_quadpotential.py b/pymc3/tests/test_quadpotential.py index a22ece3f09..7b5e62050d 100644 --- a/pymc3/tests/test_quadpotential.py +++ b/pymc3/tests/test_quadpotential.py @@ -225,15 +225,15 @@ def test_full_adapt_adaptation_window(seed=8978): for i in range(window + 1): pot.update(np.random.randn(2), None, True) assert pot._previous_update == window - assert pot._adaptation_window == window * 2 + assert pot._adaptation_window == window * pot._adaptation_window_multiplier pot = quadpotential.QuadPotentialFullAdapt( - 2, np.zeros(2), np.eye(2), 1, adaptation_window=window, doubling=False + 2, np.zeros(2), np.eye(2), 1, adaptation_window=window ) for i in range(window + 1): pot.update(np.random.randn(2), None, True) assert pot._previous_update == window - assert pot._adaptation_window == window + assert pot._adaptation_window == window * pot._adaptation_window_multiplier def test_full_adapt_not_invertible():