@@ -130,7 +130,7 @@ def __init__(
130
130
Parameters
131
131
----------
132
132
vars: list
133
- List of variables for sampler
133
+ List of value variables for sampler
134
134
S: standard deviation or covariance matrix
135
135
Some measure of variance to parameterize proposal distribution
136
136
proposal_dist: function
@@ -153,6 +153,8 @@ def __init__(
153
153
154
154
if vars is None :
155
155
vars = model .value_vars
156
+ else :
157
+ vars = [model .rvs_to_values .get (var , var ) for var in vars ]
156
158
vars = pm .inputvars (vars )
157
159
158
160
if S is None :
@@ -288,7 +290,7 @@ class BinaryMetropolis(ArrayStep):
288
290
Parameters
289
291
----------
290
292
vars: list
291
- List of variables for sampler
293
+ List of value variables for sampler
292
294
scaling: scalar or array
293
295
Initial scale factor for proposal. Defaults to 1.
294
296
tune: bool
@@ -321,6 +323,8 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
321
323
self .steps_until_tune = tune_interval
322
324
self .accepted = 0
323
325
326
+ vars = [model .rvs_to_values .get (var , var ) for var in vars ]
327
+
324
328
if not all ([v .dtype in pm .discrete_types for v in vars ]):
325
329
raise ValueError ("All variables must be Bernoulli for BinaryMetropolis" )
326
330
@@ -388,7 +392,7 @@ class BinaryGibbsMetropolis(ArrayStep):
388
392
Parameters
389
393
----------
390
394
vars: list
391
- List of variables for sampler
395
+ List of value variables for sampler
392
396
order: list or 'random'
393
397
List of integers indicating the Gibbs update order
394
398
e.g., [0, 2, 1, ...]. Default is random
@@ -410,6 +414,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
410
414
self .transit_p = transit_p
411
415
412
416
initial_point = model .initial_point
417
+ vars = [model .rvs_to_values .get (var , var ) for var in vars ]
413
418
self .dim = sum (initial_point [v .name ].size for v in vars )
414
419
415
420
if order == "random" :
@@ -490,6 +495,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
490
495
491
496
model = pm .modelcontext (model )
492
497
498
+ vars = [model .rvs_to_values .get (var , var ) for var in vars ]
493
499
vars = pm .inputvars (vars )
494
500
495
501
initial_point = model .initial_point
@@ -697,6 +703,8 @@ def __init__(
697
703
698
704
if vars is None :
699
705
vars = model .cont_vars
706
+ else :
707
+ vars = [model .rvs_to_values .get (var , var ) for var in vars ]
700
708
vars = pm .inputvars (vars )
701
709
702
710
if S is None :
@@ -846,6 +854,8 @@ def __init__(
846
854
847
855
if vars is None :
848
856
vars = model .cont_vars
857
+ else :
858
+ vars = [model .rvs_to_values .get (var , var ) for var in vars ]
849
859
vars = pm .inputvars (vars )
850
860
851
861
if S is None :
0 commit comments