
Commit c120b02

Small improvements to early nuts behaviour (#5824)
* Use all leapfrog steps for acceptance rate
* Only count positive energy errors as divergence
* Copy momentum sum to be sure
1 parent f602486 commit c120b02
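In short, the first bullet switches the reported acceptance statistic from a multinomial-weighted average over accepted proposals to a plain average of the Metropolis accept probability min(1, exp(-energy_error)) over every leapfrog step of the trajectory. A minimal standalone sketch of that aggregation (illustrative only, not pymc code; it mirrors the log_accept_sum / n_proposals bookkeeping in the nuts.py diff below):

import numpy as np

def mean_tree_accept(energy_errors):
    # Each leapfrog step contributes min(1, exp(-dE)), accumulated in log
    # space like the tree's log_accept_sum; the mean is taken over all
    # steps, not just the accepted proposals.
    log_accept_sum = -np.inf
    n_proposals = 0
    for dE in energy_errors:
        if np.isnan(dE):
            dE = np.inf
        log_accept_sum = np.logaddexp(log_accept_sum, min(0.0, -dE))
        n_proposals += 1
    return np.exp(log_accept_sum) / n_proposals

print(mean_tree_accept([0.1, -0.2, 0.05]))  # roughly 0.95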

6 files changed (+58, -36 lines)


pymc/step_methods/hmc/base_hmc.py (+1)

@@ -236,6 +236,7 @@ def astep(self, q0):
 
         stats.update(hmc_step.stats)
         stats.update(self.step_adapt.stats())
+        stats.update(self.potential.stats())
 
         return hmc_step.end.q, [stats]
 

pymc/step_methods/hmc/hmc.py (+12, -1)

@@ -51,6 +51,8 @@ class HamiltonianMC(BaseHMC):
             "process_time_diff": np.float64,
             "perf_counter_diff": np.float64,
             "perf_counter_start": np.float64,
+            "largest_eigval": np.float64,
+            "smallest_eigval": np.float64,
         }
     ]
 
@@ -162,7 +164,16 @@ def _hamiltonian_step(self, start, p0, step_size):
            "model_logp": state.model_logp,
        }
        # Retrieve State q and p data from respective RaveledVars
-        end = State(end.q.data, end.p.data, end.v, end.q_grad, end.energy, end.model_logp)
+        end = State(
+            end.q.data,
+            end.p.data,
+            end.v,
+            end.q_grad,
+            end.energy,
+            end.model_logp,
+            end.index_in_trajectory,
+        )
+        stats.update(self.potential.stats())
         return HMCStepData(end, accept_stat, div_info, stats)
 
     @staticmethod
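Because State is a plain namedtuple, the extra index_in_trajectory field has to be passed positionally everywhere a State is built, which is why the single-line constructor call above is expanded. A tiny illustration (demo values only; same field list as the integration.py diff below):

from collections import namedtuple

State = namedtuple("State", "q, p, v, q_grad, energy, model_logp, index_in_trajectory")

# New-style construction threads the trajectory index through explicitly.
s = State([0.0], [1.0], [1.0], [0.0], 0.5, -0.5, 0)
print(s.index_in_trajectory)  # 0

# The old six-argument call State([0.0], [1.0], [1.0], [0.0], 0.5, -0.5)
# would now raise a TypeError about the missing field.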

pymc/step_methods/hmc/integration.py (+11, -3)

@@ -20,7 +20,7 @@
 
 from pymc.blocking import RaveledVars
 
-State = namedtuple("State", "q, p, v, q_grad, energy, model_logp")
+State = namedtuple("State", "q, p, v, q_grad, energy, model_logp, index_in_trajectory")
 
 
 class IntegrationError(RuntimeError):
@@ -49,7 +49,7 @@ def compute_state(self, q, p):
         v = self._potential.velocity(p.data)
         kinetic = self._potential.energy(p.data, velocity=v)
         energy = kinetic - logp
-        return State(q, p, v, dlogp, energy, logp)
+        return State(q, p, v, dlogp, energy, logp, 0)
 
     def step(self, epsilon, state):
         """Leapfrog integrator step.
@@ -114,4 +114,12 @@ def _step(self, epsilon, state):
         kinetic = pot.velocity_energy(p_new.data, v_new)
         energy = kinetic - logp
 
-        return State(q_new, p_new, v_new, q_new_grad, energy, logp)
+        return State(
+            q_new,
+            p_new,
+            v_new,
+            q_new_grad,
+            energy,
+            logp,
+            state.index_in_trajectory + int(np.sign(epsilon)),
+        )
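A small standalone sketch of how the new field evolves (illustrative only, not pymc code): the integrator's _step above adds the sign of the step size, so forward steps count up from 0 and backward steps count down, and the value recorded for a draw tells you how many leapfrog steps, and in which direction, separate it from the initial point.

import numpy as np

def next_index(index_in_trajectory, epsilon):
    # Mirrors the last argument of the State constructed in _step above:
    # +1 for a forward step (epsilon > 0), -1 for a backward step.
    return index_in_trajectory + int(np.sign(epsilon))

idx = 0
for _ in range(3):  # three backward steps from the initial state
    idx = next_index(idx, -0.1)
print(idx)  # -3, i.e. the third leapfrog step in the negative direction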

pymc/step_methods/hmc/nuts.py (+27, -32)

@@ -18,7 +18,7 @@
 
 from pymc.aesaraf import floatX
 from pymc.backends.report import SamplerWarning, WarningType
-from pymc.math import logbern, logdiffexp_numpy
+from pymc.math import logbern
 from pymc.step_methods.arraystep import Competence
 from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
 from pymc.step_methods.hmc.integration import IntegrationError
@@ -78,6 +78,12 @@ class NUTS(BaseHMC):
       by the python standard library `time.perf_counter` (wall time).
    - `perf_counter_start`: The value of `time.perf_counter` at the beginning
      of the computation of the draw.
+    - `index_in_trajectory`: This is usually only interesting for debugging
+      purposes. This indicates the position of the posterior draw in the
+      trajectory. Eg a -4 would indicate that the draw was the result of the
+      fourth leapfrog step in negative direction.
+    - `largest_eigval` and `smallest_eigval`: Experimental statistics for
+      some mass matrix adaptation algorithms. This is nan if it is not used.
 
     References
     ----------
@@ -105,6 +111,9 @@ class NUTS(BaseHMC):
             "process_time_diff": np.float64,
             "perf_counter_diff": np.float64,
             "perf_counter_start": np.float64,
+            "largest_eigval": np.float64,
+            "smallest_eigval": np.float64,
+            "index_in_trajectory": np.int64,
         }
     ]
 
@@ -219,12 +228,12 @@ def warnings(self):
 
 
 # A proposal for the next position
-Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept_weighted, logp")
+Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory")
 
 # A subtree of the binary tree built by nuts.
 Subtree = namedtuple(
     "Subtree",
-    "left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals",
+    "left, right, p_sum, proposal, log_size",
 )
 
 
@@ -252,10 +261,10 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
         self.start_energy = np.array(start.energy)
 
         self.left = self.right = start
-        self.proposal = Proposal(start.q.data, start.q_grad, start.energy, 1.0, start.model_logp)
+        self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0)
         self.depth = 0
         self.log_size = 0
-        self.log_weighted_accept_sum = -np.inf
+        self.log_accept_sum = -np.inf
         self.mean_tree_accept = 0.0
         self.n_proposals = 0
         self.p_sum = start.p.data.copy()
@@ -279,7 +288,7 @@ def extend(self, direction):
            )
            leftmost_begin, leftmost_end = self.left, self.right
            rightmost_begin, rightmost_end = tree.left, tree.right
-            leftmost_p_sum = self.p_sum
+            leftmost_p_sum = self.p_sum.copy()
            rightmost_p_sum = tree.p_sum
            self.right = tree.right
        else:
@@ -289,11 +298,10 @@ def extend(self, direction):
            leftmost_begin, leftmost_end = tree.right, tree.left
            rightmost_begin, rightmost_end = self.left, self.right
            leftmost_p_sum = tree.p_sum
-            rightmost_p_sum = self.p_sum
+            rightmost_p_sum = self.p_sum.copy()
            self.left = tree.right
 
        self.depth += 1
-        self.n_proposals += tree.n_proposals
 
        if diverging or turning:
            return diverging, turning
@@ -303,9 +311,6 @@
            self.proposal = tree.proposal
 
        self.log_size = np.logaddexp(self.log_size, tree.log_size)
-        self.log_weighted_accept_sum = np.logaddexp(
-            self.log_weighted_accept_sum, tree.log_weighted_accept_sum
-        )
        self.p_sum[:] += tree.p_sum
 
        # Additional turning check only when tree depth > 0 to avoid redundant work
@@ -336,30 +341,30 @@ def _single_step(self, left, epsilon):
            if np.isnan(energy_change):
                energy_change = np.inf
 
+            self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change))
+
            if np.abs(energy_change) > np.abs(self.max_energy_change):
                self.max_energy_change = energy_change
-            if np.abs(energy_change) < self.Emax:
+            if energy_change < self.Emax:
                # Acceptance statistic
                # e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
                # Saturated Metropolis accept probability with Boltzmann weight
-                # if h - H0 < 0
-                log_p_accept_weighted = -energy_change + min(0.0, -energy_change)
                log_size = -energy_change
                proposal = Proposal(
                    right.q.data,
                    right.q_grad,
                    right.energy,
-                    log_p_accept_weighted,
                    right.model_logp,
+                    right.index_in_trajectory,
                )
-                tree = Subtree(
-                    right, right, right.p.data, proposal, log_size, log_p_accept_weighted, 1
-                )
+                tree = Subtree(right, right, right.p.data, proposal, log_size)
                return tree, None, False
            else:
                error_msg = f"Energy change in leapfrog step is too large: {energy_change}."
                error = None
-        tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
+        finally:
+            self.n_proposals += 1
+        tree = Subtree(None, None, None, None, -np.inf)
        divergance_info = DivergenceInfo(error_msg, error, left, right)
        return tree, divergance_info, False
 
@@ -387,31 +392,20 @@ def _build_subtree(self, left, depth, epsilon):
                turning = turning | turning1 | turning2
 
            log_size = np.logaddexp(tree1.log_size, tree2.log_size)
-            log_weighted_accept_sum = np.logaddexp(
-                tree1.log_weighted_accept_sum, tree2.log_weighted_accept_sum
-            )
            if logbern(tree2.log_size - log_size):
                proposal = tree2.proposal
            else:
                proposal = tree1.proposal
        else:
            p_sum = tree1.p_sum
            log_size = tree1.log_size
-            log_weighted_accept_sum = tree1.log_weighted_accept_sum
            proposal = tree1.proposal
 
-        n_proposals = tree1.n_proposals + tree2.n_proposals
-
-        tree = Subtree(left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals)
+        tree = Subtree(left, right, p_sum, proposal, log_size)
        return tree, diverging, turning
 
    def stats(self):
-        # Update accept stat if any subtrees were accepted
-        if self.log_size > 0:
-            # Remove contribution from initial state which is always a perfect
-            # accept
-            log_sum_weight = logdiffexp_numpy(self.log_size, 0.0)
-            self.mean_tree_accept = np.exp(self.log_weighted_accept_sum - log_sum_weight)
+        self.mean_tree_accept = np.exp(self.log_accept_sum) / self.n_proposals
        return {
            "depth": self.depth,
            "mean_tree_accept": self.mean_tree_accept,
@@ -420,4 +414,5 @@ def stats(self):
            "tree_size": self.n_proposals,
            "max_energy_error": self.max_energy_change,
            "model_logp": self.proposal.logp,
+            "index_in_trajectory": self.proposal.index_in_trajectory,
        }
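Two of the commit-message bullets are visible here. Only positive energy errors count as divergences now: the check became `energy_change < self.Emax`, so a trajectory is only cut short when the Hamiltonian blows up, not when the energy error is large and negative. And the acceptance statistic is accumulated for every leapfrog step (log_accept_sum divided by n_proposals) rather than weighted by the multinomial proposal weights. A minimal contrast of the old and new divergence test (illustrative only, not pymc code; Emax defaults to 1000):

def is_divergent_old(energy_change, Emax=1000.0):
    # Before: any large energy error, positive or negative, ended the trajectory.
    return not (abs(energy_change) < Emax)

def is_divergent_new(energy_change, Emax=1000.0):
    # After: only a large positive energy error counts as a divergence.
    return not (energy_change < Emax)

print(is_divergent_old(-2000.0), is_divergent_new(-2000.0))  # True False
print(is_divergent_old(2000.0), is_divergent_new(2000.0))    # True True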

pymc/step_methods/hmc/quadpotential.py (+4)

@@ -136,6 +136,9 @@ def raise_ok(self, map_info=None):
     def reset(self):
         pass
 
+    def stats(self):
+        return {"largest_eigval": np.nan, "smallest_eigval": np.nan}
+
 
 def isquadpotential(value):
     """Check whether an object might be a QuadPotential object."""
@@ -254,6 +257,7 @@ def random(self):
 
     def _update_from_weightvar(self, weightvar):
         weightvar.current_variance(out=self._var)
+        self._var = np.clip(self._var, 1e-12, 1e12)
         np.sqrt(self._var, out=self._stds)
         np.divide(1, self._stds, out=self._inv_stds)
         self._var_aesara.set_value(self._var)
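Two small robustness tweaks live here: the base potential now exposes a stats() method returning nan eigenvalues, so the new sampler-stat columns exist even when no eigenvalue-tracking adaptation is active, and the adapted diagonal variance is clipped before the square root is taken. A rough standalone illustration of the clipping (made-up numbers, not pymc code):

import numpy as np

# A variance estimate can collapse toward 0 (or explode) early in tuning;
# clipping keeps the subsequent sqrt and reciprocal finite and well scaled.
var = np.array([4.0, 1e-300, 2.5e15])
var = np.clip(var, 1e-12, 1e12)
stds = np.sqrt(var)
inv_stds = 1.0 / stds
print(var)       # tiny and huge entries clipped to 1e-12 and 1e12
print(inv_stds)  # 0.5, 1e6, 1e-6, all finite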

pymc/tests/test_step.py (+3)

@@ -641,6 +641,9 @@ def test_sampler_stats(self):
             "perf_counter_diff",
             "perf_counter_start",
             "process_time_diff",
+            "index_in_trajectory",
+            "largest_eigval",
+            "smallest_eigval",
         }
         assert trace.stat_names == expected_stat_names
         for varname in trace.stat_names:
