from pymc.aesaraf import floatX
from pymc.backends.report import SamplerWarning, WarningType
- from pymc.math import logbern, logdiffexp_numpy
+ from pymc.math import logbern
from pymc.step_methods.arraystep import Competence
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError
@@ -78,6 +78,12 @@ class NUTS(BaseHMC):
      by the python standard library `time.perf_counter` (wall time).
    - `perf_counter_start`: The value of `time.perf_counter` at the beginning
      of the computation of the draw.
+    - `index_in_trajectory`: This is usually only interesting for debugging
+      purposes. It indicates the position of the posterior draw within the
+      trajectory. E.g. a value of -4 means that the draw was the result of
+      the fourth leapfrog step in the negative direction.
+    - `largest_eigval` and `smallest_eigval`: Experimental statistics for
+      some mass matrix adaptation algorithms. These are nan if not used.

    References
    ----------
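As a side note for reviewers: below is a minimal sketch of how these new per-draw statistics could be inspected after sampling, assuming a recent PyMC where `pm.sample` returns an ArviZ InferenceData. The toy model is hypothetical; the stat names are the ones added above, and they land in the `sample_stats` group.

import numpy as np
import pymc as pm

# Hypothetical toy model; any model sampled with NUTS exposes the same stats.
with pm.Model():
    pm.Normal("x", 0.0, 1.0)
    idata = pm.sample(draws=500, tune=500, chains=1, return_inferencedata=True)

stats = idata.sample_stats
# Signed leapfrog index of each accepted draw (negative = backward direction).
print(stats["index_in_trajectory"].values[0][:10])
# Experimental eigenvalue stats stay nan unless a mass matrix adaptation fills them.
print(np.isnan(stats["largest_eigval"].values).all())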
@@ -105,6 +111,9 @@ class NUTS(BaseHMC):
"process_time_diff": np.float64,
"perf_counter_diff": np.float64,
"perf_counter_start": np.float64,
+ "largest_eigval": np.float64,
+ "smallest_eigval": np.float64,
+ "index_in_trajectory": np.int64,
}
]
@@ -219,12 +228,12 @@ def warnings(self):
# A proposal for the next position
- Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept_weighted, logp")
+ Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory")

# A subtree of the binary tree built by nuts.
Subtree = namedtuple(
"Subtree",
- "left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals",
+ "left, right, p_sum, proposal, log_size",
)
@@ -252,10 +261,10 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
self.start_energy = np.array(start.energy)

self.left = self.right = start
- self.proposal = Proposal(start.q.data, start.q_grad, start.energy, 1.0, start.model_logp)
+ self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0)
self.depth = 0
self.log_size = 0
- self.log_weighted_accept_sum = -np.inf
+ self.log_accept_sum = -np.inf
self.mean_tree_accept = 0.0
self.n_proposals = 0
self.p_sum = start.p.data.copy()
@@ -279,7 +288,7 @@ def extend(self, direction):
)
leftmost_begin, leftmost_end = self.left, self.right
rightmost_begin, rightmost_end = tree.left, tree.right
- leftmost_p_sum = self.p_sum
+ leftmost_p_sum = self.p_sum.copy()
rightmost_p_sum = tree.p_sum
self.right = tree.right
else:
@@ -289,11 +298,10 @@ def extend(self, direction):
leftmost_begin, leftmost_end = tree.right, tree.left
rightmost_begin, rightmost_end = self.left, self.right
leftmost_p_sum = tree.p_sum
- rightmost_p_sum = self.p_sum
+ rightmost_p_sum = self.p_sum.copy()
self.left = tree.right

self.depth += 1
- self.n_proposals += tree.n_proposals

if diverging or turning:
return diverging, turning
@@ -303,9 +311,6 @@ def extend(self, direction):
self.proposal = tree.proposal

self.log_size = np.logaddexp(self.log_size, tree.log_size)
- self.log_weighted_accept_sum = np.logaddexp(
- self.log_weighted_accept_sum, tree.log_weighted_accept_sum
- )
self.p_sum[:] += tree.p_sum

# Additional turning check only when tree depth > 0 to avoid redundant work
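The `.copy()` calls added in `extend` matter because `self.p_sum` is updated in place just above (`self.p_sum[:] += tree.p_sum`); without the copy, the saved momentum sums would alias the mutated buffer. A self-contained sketch of that NumPy aliasing behaviour (standalone illustration, not sampler code):

import numpy as np

p_sum = np.array([1.0, 2.0])
saved_view = p_sum          # alias: both names point at the same buffer
saved_copy = p_sum.copy()   # independent snapshot

p_sum[:] += np.array([10.0, 10.0])  # in-place update, as in `self.p_sum[:] += tree.p_sum`

print(saved_view)  # [11. 12.] -- silently changed together with p_sum
print(saved_copy)  # [ 1.  2.] -- still the value at snapshot time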
@@ -336,30 +341,30 @@ def _single_step(self, left, epsilon):
if np.isnan(energy_change):
energy_change = np.inf

+ self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change))
+
if np.abs(energy_change) > np.abs(self.max_energy_change):
self.max_energy_change = energy_change
- if np.abs(energy_change) < self.Emax:
+ if energy_change < self.Emax:
# Acceptance statistic
# e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
# Saturated Metropolis accept probability with Boltzmann weight
- # if h - H0 < 0
- log_p_accept_weighted = -energy_change + min(0.0, -energy_change)
log_size = -energy_change
proposal = Proposal(
right.q.data,
right.q_grad,
right.energy,
- log_p_accept_weighted,
right.model_logp,
+ right.index_in_trajectory,
)
- tree = Subtree(
- right, right, right.p.data, proposal, log_size, log_p_accept_weighted, 1
- )
+ tree = Subtree(right, right, right.p.data, proposal, log_size)
return tree, None, False
else:
error_msg = f"Energy change in leapfrog step is too large: {energy_change}."
error = None
- tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
+ finally:
+ self.n_proposals += 1
+ tree = Subtree(None, None, None, None, -np.inf)
divergance_info = DivergenceInfo(error_msg, error, left, right)
return tree, divergance_info, False
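The new `log_accept_sum` line above accumulates, in log space, the saturated Metropolis acceptance probability min(1, exp(-ΔE)) of every leapfrog step. A small standalone sketch (hypothetical energy changes, not sampler code) of what that running logaddexp represents and how the reported mean falls out of it:

import numpy as np

energy_changes = [0.3, -0.1, 2.0, 0.05]  # hypothetical per-step ΔE = H(q_n, p_n) - H(q_0, p_0)

log_accept_sum = -np.inf  # log(0): empty running sum
n_proposals = 0
for energy_change in energy_changes:
    # log-add the per-step accept probability min(1, exp(-ΔE))
    log_accept_sum = np.logaddexp(log_accept_sum, min(0, -energy_change))
    n_proposals += 1

# Same quantity computed directly, without the log-space accumulation
direct = sum(min(1.0, np.exp(-de)) for de in energy_changes)
assert np.isclose(np.exp(log_accept_sum), direct)

# This average is what stats() below reports as `mean_tree_accept`
print(np.exp(log_accept_sum) / n_proposals)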
@@ -387,31 +392,20 @@ def _build_subtree(self, left, depth, epsilon):
turning = turning | turning1 | turning2

log_size = np.logaddexp(tree1.log_size, tree2.log_size)
- log_weighted_accept_sum = np.logaddexp(
- tree1.log_weighted_accept_sum, tree2.log_weighted_accept_sum
- )
if logbern(tree2.log_size - log_size):
proposal = tree2.proposal
else:
proposal = tree1.proposal
else:
p_sum = tree1.p_sum
log_size = tree1.log_size
- log_weighted_accept_sum = tree1.log_weighted_accept_sum
proposal = tree1.proposal

- n_proposals = tree1.n_proposals + tree2.n_proposals
-
- tree = Subtree(left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals)
+ tree = Subtree(left, right, p_sum, proposal, log_size)
return tree, diverging, turning

def stats(self):
- # Update accept stat if any subtrees were accepted
- if self.log_size > 0:
- # Remove contribution from initial state which is always a perfect
- # accept
- log_sum_weight = logdiffexp_numpy(self.log_size, 0.0)
- self.mean_tree_accept = np.exp(self.log_weighted_accept_sum - log_sum_weight)
+ self.mean_tree_accept = np.exp(self.log_accept_sum) / self.n_proposals
return {
"depth": self.depth,
"mean_tree_accept": self.mean_tree_accept,
@@ -420,4 +414,5 @@ def stats(self):
"tree_size": self.n_proposals,
"max_energy_error": self.max_energy_change,
"model_logp": self.proposal.logp,
+ "index_in_trajectory": self.proposal.index_in_trajectory,
}