Skip to content

Commit 5fcc3e5

Browse files
committed
Convert RVs to value vars in step methods
1 parent d926746 commit 5fcc3e5

File tree

11 files changed

+98
-15
lines changed

11 files changed

+98
-15
lines changed

pymc3/step_methods/arraystep.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class ArrayStep(BlockedStep):
134134
Parameters
135135
----------
136136
vars: list
137-
List of variables for sampler.
137+
List of value variables for sampler.
138138
fs: list of logp Aesara functions
139139
allvars: Boolean (default False)
140140
blocked: Boolean (default True)
@@ -190,7 +190,7 @@ def __init__(self, vars, shared, blocked=True):
190190
"""
191191
Parameters
192192
----------
193-
vars: list of sampling variables
193+
vars: list of sampling value variables
194194
shared: dict of Aesara variable -> shared variable
195195
blocked: Boolean (default True)
196196
"""
@@ -235,7 +235,7 @@ def __init__(self, vars, shared, blocked=True):
235235
"""
236236
Parameters
237237
----------
238-
vars: list of sampling variables
238+
vars: list of sampling value variables
239239
shared: dict of Aesara variable -> shared variable
240240
blocked: Boolean (default True)
241241
"""

pymc3/step_methods/elliptical_slice.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import numpy.random as nr
1818

19+
from pymc3.aesaraf import inputvars
1920
from pymc3.model import modelcontext
2021
from pymc3.step_methods.arraystep import ArrayStep, Competence
2122

@@ -61,7 +62,7 @@ class EllipticalSlice(ArrayStep):
6162
Parameters
6263
----------
6364
vars: list
64-
List of variables for sampler.
65+
List of value variables for sampler.
6566
prior_cov: array, optional
6667
Covariance matrix of the multivariate Gaussian prior.
6768
prior_chol: array, optional
@@ -88,6 +89,8 @@ def __init__(self, vars=None, prior_cov=None, prior_chol=None, model=None, **kwa
8889

8990
if vars is None:
9091
vars = self.model.cont_vars
92+
else:
93+
vars = [self.model.rvs_to_values.get(var, var) for var in vars]
9194
vars = inputvars(vars)
9295

9396
super().__init__(vars, [self.model.fastlogp], **kwargs)

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def __init__(
8989

9090
if vars is None:
9191
vars = self._model.cont_vars
92+
else:
93+
vars = [self._model.rvs_to_values.get(var, var) for var in vars]
9294

9395
super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **aesara_kwargs)
9496

pymc3/step_methods/hmc/hmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
6060
Parameters
6161
----------
6262
vars: list, default=None
63-
List of Aesara variables. If None, all continuous RVs from the
63+
List of value variables. If None, all continuous RVs from the
6464
model are included.
6565
path_length: float, default=2
6666
Total length to travel

pymc3/step_methods/hmc/nuts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs)
115115
Parameters
116116
----------
117117
vars: list, default=None
118-
List of Aesara variables. If None, all continuous RVs from the
118+
List of value variables. If None, all continuous RVs from the
119119
model are included.
120120
Emax: float, default 1000
121121
Maximum energy change allowed during leapfrog steps. Larger

pymc3/step_methods/metropolis.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(
130130
Parameters
131131
----------
132132
vars: list
133-
List of variables for sampler
133+
List of value variables for sampler
134134
S: standard deviation or covariance matrix
135135
Some measure of variance to parameterize proposal distribution
136136
proposal_dist: function
@@ -153,6 +153,8 @@ def __init__(
153153

154154
if vars is None:
155155
vars = model.value_vars
156+
else:
157+
vars = [model.rvs_to_values.get(var, var) for var in vars]
156158
vars = pm.inputvars(vars)
157159

158160
if S is None:
@@ -288,7 +290,7 @@ class BinaryMetropolis(ArrayStep):
288290
Parameters
289291
----------
290292
vars: list
291-
List of variables for sampler
293+
List of value variables for sampler
292294
scaling: scalar or array
293295
Initial scale factor for proposal. Defaults to 1.
294296
tune: bool
@@ -321,6 +323,8 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
321323
self.steps_until_tune = tune_interval
322324
self.accepted = 0
323325

326+
vars = [model.rvs_to_values.get(var, var) for var in vars]
327+
324328
if not all([v.dtype in pm.discrete_types for v in vars]):
325329
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
326330

@@ -388,7 +392,7 @@ class BinaryGibbsMetropolis(ArrayStep):
388392
Parameters
389393
----------
390394
vars: list
391-
List of variables for sampler
395+
List of value variables for sampler
392396
order: list or 'random'
393397
List of integers indicating the Gibbs update order
394398
e.g., [0, 2, 1, ...]. Default is random
@@ -410,6 +414,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
410414
self.transit_p = transit_p
411415

412416
initial_point = model.initial_point
417+
vars = [model.rvs_to_values.get(var, var) for var in vars]
413418
self.dim = sum(initial_point[v.name].size for v in vars)
414419

415420
if order == "random":
@@ -490,6 +495,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
490495

491496
model = pm.modelcontext(model)
492497

498+
vars = [model.rvs_to_values.get(var, var) for var in vars]
493499
vars = pm.inputvars(vars)
494500

495501
initial_point = model.initial_point
@@ -697,6 +703,8 @@ def __init__(
697703

698704
if vars is None:
699705
vars = model.cont_vars
706+
else:
707+
vars = [model.rvs_to_values.get(var, var) for var in vars]
700708
vars = pm.inputvars(vars)
701709

702710
if S is None:
@@ -846,6 +854,8 @@ def __init__(
846854

847855
if vars is None:
848856
vars = model.cont_vars
857+
else:
858+
vars = [model.rvs_to_values.get(var, var) for var in vars]
849859
vars = pm.inputvars(vars)
850860

851861
if S is None:

pymc3/step_methods/mlda.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def __init__(self, *args, **kwargs):
7474
value_vars = kwargs.get("vars", None)
7575
if value_vars is None:
7676
value_vars = model.value_vars
77+
else:
78+
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
7779
value_vars = pm.inputvars(value_vars)
7880
shared = pm.make_shared_replacements(initial_values, value_vars, model)
7981

@@ -142,6 +144,8 @@ def __init__(self, *args, **kwargs):
142144
value_vars = kwargs.get("vars", None)
143145
if value_vars is None:
144146
value_vars = model.value_vars
147+
else:
148+
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
145149
value_vars = pm.inputvars(value_vars)
146150
shared = pm.make_shared_replacements(initial_values, value_vars, model)
147151

@@ -218,7 +222,7 @@ class MLDA(ArrayStepShared):
218222
Note this list excludes the model passed to the model
219223
argument above, which is the finest available.
220224
vars : list
221-
List of variables for sampler
225+
List of value variables for sampler
222226
base_sampler : string
223227
Sampler used in the base (coarsest) chain. Can be 'Metropolis' or
224228
'DEMetropolisZ'. Defaults to 'DEMetropolisZ'.
@@ -549,6 +553,8 @@ def __init__(
549553
# Process model variables
550554
if value_vars is None:
551555
value_vars = model.value_vars
556+
else:
557+
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
552558
value_vars = pm.inputvars(value_vars)
553559
self.vars = value_vars
554560
self.var_names = [var.name for var in self.vars]

pymc3/step_methods/pgbart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class PGBART(ArrayStepShared):
3939
Parameters
4040
----------
4141
vars: list
42-
List of variables for sampler
42+
List of value variables for sampler
4343
num_particles : int
4444
Number of particles for the conditional SMC sampler. Defaults to 10
4545
max_stages : int

pymc3/step_methods/sgmcmc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class BaseStochasticGradient(ArrayStepShared):
8787
Parameters
8888
----------
8989
vars: list
90-
List of variables for sampler
90+
List of value variables for sampler
9191
batch_size`: int
9292
Batch Size for each step
9393
total_size: int
@@ -132,6 +132,8 @@ def __init__(
132132

133133
if vars is None:
134134
vars = model.value_vars
135+
else:
136+
vars = [model.rvs_to_values.get(var, var) for var in vars]
135137

136138
vars = inputvars(vars)
137139

pymc3/step_methods/slicer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class Slice(ArrayStep):
3535
Parameters
3636
----------
3737
vars: list
38-
List of variables for sampler.
38+
List of value variables for sampler.
3939
w: float
4040
Initial width of slice (Defaults to 1).
4141
tune: bool
@@ -57,6 +57,8 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, *
5757

5858
if vars is None:
5959
vars = self.model.cont_vars
60+
else:
61+
vars = [self.model.rvs_to_values.get(var, var) for var in vars]
6062
vars = inputvars(vars)
6163

6264
super().__init__(vars, [self.model.fastlogp], **kwargs)

0 commit comments

Comments
 (0)