Skip to content

Commit 2295f0b

Browse files
ColCarrolltwiecki
authored andcommitted
Clean up HMC code (#3447)
* Clean up HMC code * Fix some tests * Add back target_accept attribute
1 parent fbb864f commit 2295f0b

File tree

5 files changed

+19
-36
lines changed

5 files changed

+19
-36
lines changed

pymc3/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def __call__(self, array, grad_out=None, extra_vars=None):
492492
if grad_out is None:
493493
return logp, dlogp
494494
else:
495-
out[...] = dlogp
495+
np.copyto(out, dlogp)
496496
return logp
497497

498498
@property
@@ -737,7 +737,7 @@ def logpt(self):
737737
def logp_nojact(self):
738738
"""Theano scalar of log-probability of the model but without the jacobian
739739
if transformed Random Variable is presented.
740-
Note that If there is no transformed variable in the model, logp_nojact
740+
Note that If there is no transformed variable in the model, logp_nojact
741741
will be the same as logpt as there is no need for Jacobian correction.
742742
"""
743743
with self:

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state")
2121

22-
2322
class BaseHMC(arraystep.GradientSharedStep):
2423
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
2524

@@ -34,7 +33,6 @@ def __init__(
3433
model=None,
3534
blocked=True,
3635
potential=None,
37-
integrator="leapfrog",
3836
dtype=None,
3937
Emax=1000,
4038
target_accept=0.8,
@@ -79,11 +77,10 @@ def __init__(
7977
size = self._logp_dlogp_func.size
8078

8179
self.step_size = step_scale / (size ** 0.25)
82-
self.target_accept = target_accept
8380
self.step_adapt = step_sizes.DualAverageAdaptation(
8481
self.step_size, target_accept, gamma, k, t0
8582
)
86-
83+
self.target_accept = target_accept
8784
self.tune = True
8885

8986
if scaling is None and potential is None:

pymc3/step_methods/hmc/integration.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def compute_state(self, q, p):
3232
energy = kinetic - logp
3333
return State(q, p, v, dlogp, energy, logp)
3434

35-
def step(self, epsilon, state, out=None):
35+
def step(self, epsilon, state):
3636
"""Leapfrog integrator step.
3737
3838
Half a momentum update, full position update, half momentum update.
@@ -51,7 +51,7 @@ def step(self, epsilon, state, out=None):
5151
None if `out` is provided, else a State namedtuple
5252
"""
5353
try:
54-
return self._step(epsilon, state, out=None)
54+
return self._step(epsilon, state)
5555
except linalg.LinAlgError as err:
5656
msg = "LinAlgError during leapfrog step."
5757
raise IntegrationError(msg)
@@ -64,26 +64,20 @@ def step(self, epsilon, state, out=None):
6464
else:
6565
raise
6666

67-
def _step(self, epsilon, state, out=None):
68-
pot = self._potential
67+
def _step(self, epsilon, state):
6968
axpy = linalg.blas.get_blas_funcs('axpy', dtype=self._dtype)
69+
pot = self._potential
7070

71-
q, p, v, q_grad, energy, logp = state
72-
if out is None:
73-
q_new = q.copy()
74-
p_new = p.copy()
75-
v_new = np.empty_like(q)
76-
q_new_grad = np.empty_like(q)
77-
else:
78-
q_new, p_new, v_new, q_new_grad, energy = out
79-
q_new[:] = q
80-
p_new[:] = p
71+
q_new = state.q.copy()
72+
p_new = state.p.copy()
73+
v_new = np.empty_like(q_new)
74+
q_new_grad = np.empty_like(q_new)
8175

8276
dt = 0.5 * epsilon
8377

8478
# p is already stored in p_new
8579
# p_new = p + dt * q_grad
86-
axpy(q_grad, p_new, a=dt)
80+
axpy(state.q_grad, p_new, a=dt)
8781

8882
pot.velocity(p_new, out=v_new)
8983
# q is already stored in q_new
@@ -98,8 +92,4 @@ def _step(self, epsilon, state, out=None):
9892
kinetic = pot.velocity_energy(p_new, v_new)
9993
energy = kinetic - logp
10094

101-
if out is not None:
102-
out.energy = energy
103-
return
104-
else:
105-
return State(q_new, p_new, v_new, q_new_grad, energy, logp)
95+
return State(q_new, p_new, v_new, q_new_grad, energy, logp)

pymc3/tests/test_distributions_timeseries.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from ..theanof import floatX
66

77
import numpy as np
8+
import pytest
9+
10+
pytestmark = pytest.mark.usefixtures('seeded_test')
11+
812

913
def test_AR():
1014
# AR1

pymc3/tests/test_posteriors.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,11 @@ class TestSliceUniform(sf.SliceFixture, sf.UniformFixture):
3434

3535

3636
class TestNUTSUniform2(TestNUTSUniform):
37-
step_args = {'target_accept': 0.95, 'integrator': 'two-stage'}
37+
step_args = {'target_accept': 0.95}
3838

3939

4040
class TestNUTSUniform3(TestNUTSUniform):
41-
step_args = {'target_accept': 0.80, 'integrator': 'two-stage'}
42-
43-
44-
class TestNUTSUniform4(TestNUTSUniform):
45-
step_args = {'target_accept': 0.95, 'integrator': 'three-stage'}
46-
47-
48-
class TestNUTSUniform5(TestNUTSUniform):
49-
step_args = {'target_accept': 0.80, 'integrator': 'three-stage'}
41+
step_args = {'target_accept': 0.80}
5042

5143

5244
class TestNUTSNormal(sf.NutsFixture, sf.NormalFixture):

0 commit comments

Comments
 (0)