Skip to content

Commit a1920c6

Browse files
don't tune DEMetropolis by default
+ tune argument now one of None,scaling,lambda + support for tuning lambda (closes #3720) + added test to check checking of tune setting + both scaling and lambda are recorded in the sampler stats
1 parent 5d60f8c commit a1920c6

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

pymc3/step_methods/metropolis.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,8 @@ class DEMetropolis(PopulationArrayStepShared):
510510
S (and n). Defaults to Uniform(-S,+S).
511511
scaling : scalar or array
512512
Initial scale factor for epsilon. Defaults to 0.001
513-
tune : bool
514-
Flag for tuning the scaling. Defaults to True.
513+
tune : str
514+
Which hyperparameter to tune. Defaults to None, but can also be 'scaling' or 'lambda'.
515515
tune_interval : int
516516
The frequency of tuning. Defaults to 100 iterations.
517517
model : PyMC Model
@@ -536,10 +536,11 @@ class DEMetropolis(PopulationArrayStepShared):
536536
'accepted': np.bool,
537537
'tune': np.bool,
538538
'scaling': np.float64,
539+
'lambda': np.float64,
539540
}]
540541

541542
def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.001,
542-
tune=True, tune_interval=100, model=None, mode=None, **kwargs):
543+
tune=None, tune_interval=100, model=None, mode=None, **kwargs):
543544

544545
model = pm.modelcontext(model)
545546

@@ -549,7 +550,7 @@ def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.0
549550

550551
if S is None:
551552
S = np.ones(model.ndim)
552-
553+
553554
if proposal_dist is not None:
554555
self.proposal_dist = proposal_dist(S)
555556
else:
@@ -559,6 +560,8 @@ def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.0
559560
if lamb is None:
560561
lamb = 2.38 / np.sqrt(2 * model.ndim)
561562
self.lamb = float(lamb)
563+
if not tune in {None, 'scaling', 'lambda'}:
564+
raise ValueError('The parameter "tune" must be one of {None, scaling, lambda}')
562565
self.tune = tune
563566
self.tune_interval = tune_interval
564567
self.steps_until_tune = tune_interval
@@ -572,9 +575,10 @@ def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.0
572575

573576
def astep(self, q0):
574577
if not self.steps_until_tune and self.tune:
575-
# Tune scaling parameter
576-
self.scaling = tune(
577-
self.scaling, self.accepted / float(self.tune_interval))
578+
if self.tune == 'scaling':
579+
self.scaling = tune(self.scaling, self.accepted / float(self.tune_interval))
580+
elif self.tune == 'lambda':
581+
self.lamb = tune(self.lamb, self.accepted / float(self.tune_interval))
578582
# Reset counter
579583
self.steps_until_tune = self.tune_interval
580584
self.accepted = 0
@@ -598,6 +602,7 @@ def astep(self, q0):
598602
stats = {
599603
'tune': self.tune,
600604
'scaling': self.scaling,
605+
'lambda': self.lamb,
601606
'accept': np.exp(accept),
602607
'accepted': accepted
603608
}

pymc3/tests/test_step.py

+18
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,24 @@ def test_demcmc_warning_on_small_populations(self):
719719
)
720720
pass
721721

722+
def test_demcmc_tune_parameter(self):
723+
"""Tests that validity of the tune setting is checked"""
724+
with Model() as model:
725+
Normal("n", mu=0, sigma=1, shape=(2,3))
726+
727+
step = DEMetropolis()
728+
assert step.tune is None
729+
730+
step = DEMetropolis(tune='scaling')
731+
assert step.tune == 'scaling'
732+
733+
step = DEMetropolis(tune='lambda')
734+
assert step.tune == 'lambda'
735+
736+
with pytest.raises(ValueError):
737+
DEMetropolis(tune='foo')
738+
pass
739+
722740
def test_nonparallelized_chains_are_random(self):
723741
with Model() as model:
724742
x = Normal("x", 0, 1)

0 commit comments

Comments
 (0)