@@ -510,8 +510,8 @@ class DEMetropolis(PopulationArrayStepShared):
510510 S (and n). Defaults to Uniform(-S,+S).
511511 scaling : scalar or array
512512 Initial scale factor for epsilon. Defaults to 0.001
513- tune : bool
514- Flag for tuning the scaling . Defaults to True .
513+ tune : str
514+ Which hyperparameter to tune . Defaults to None, but can also be 'scaling' or 'lambda' .
515515 tune_interval : int
516516 The frequency of tuning. Defaults to 100 iterations.
517517 model : PyMC Model
@@ -536,10 +536,11 @@ class DEMetropolis(PopulationArrayStepShared):
536536 'accepted' : np .bool ,
537537 'tune' : np .bool ,
538538 'scaling' : np .float64 ,
539+ 'lambda' : np .float64 ,
539540 }]
540541
541542 def __init__ (self , vars = None , S = None , proposal_dist = None , lamb = None , scaling = 0.001 ,
542- tune = True , tune_interval = 100 , model = None , mode = None , ** kwargs ):
543+ tune = None , tune_interval = 100 , model = None , mode = None , ** kwargs ):
543544
544545 model = pm .modelcontext (model )
545546
@@ -549,7 +550,7 @@ def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.0
549550
550551 if S is None :
551552 S = np .ones (model .ndim )
552-
553+
553554 if proposal_dist is not None :
554555 self .proposal_dist = proposal_dist (S )
555556 else :
@@ -559,6 +560,8 @@ def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.0
559560 if lamb is None :
560561 lamb = 2.38 / np .sqrt (2 * model .ndim )
561562 self .lamb = float (lamb )
563+ if not tune in {None , 'scaling' , 'lambda' }:
564+ raise ValueError ('The parameter "tune" must be one of {None, scaling, lambda}' )
562565 self .tune = tune
563566 self .tune_interval = tune_interval
564567 self .steps_until_tune = tune_interval
@@ -572,9 +575,10 @@ def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.0
572575
573576 def astep (self , q0 ):
574577 if not self .steps_until_tune and self .tune :
575- # Tune scaling parameter
576- self .scaling = tune (
577- self .scaling , self .accepted / float (self .tune_interval ))
578+ if self .tune == 'scaling' :
579+ self .scaling = tune (self .scaling , self .accepted / float (self .tune_interval ))
580+ elif self .tune == 'lambda' :
581+ self .lamb = tune (self .lamb , self .accepted / float (self .tune_interval ))
578582 # Reset counter
579583 self .steps_until_tune = self .tune_interval
580584 self .accepted = 0
@@ -598,6 +602,7 @@ def astep(self, q0):
598602 stats = {
599603 'tune' : self .tune ,
600604 'scaling' : self .scaling ,
605+ 'lambda' : self .lamb ,
601606 'accept' : np .exp (accept ),
602607 'accepted' : accepted
603608 }
0 commit comments