Skip to content

Commit 5fcc3e5

Browse files
committed
Convert RVs to value vars in step methods
1 parent d926746 commit 5fcc3e5

File tree

11 files changed

+98
-15
lines changed

11 files changed

+98
-15
lines changed

pymc3/step_methods/arraystep.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class ArrayStep(BlockedStep):
134134
Parameters
135135
----------
136136
vars: list
137-
List of variables for sampler.
137+
List of value variables for sampler.
138138
fs: list of logp Aesara functions
139139
allvars: Boolean (default False)
140140
blocked: Boolean (default True)
@@ -190,7 +190,7 @@ def __init__(self, vars, shared, blocked=True):
190190
"""
191191
Parameters
192192
----------
193-
vars: list of sampling variables
193+
vars: list of sampling value variables
194194
shared: dict of Aesara variable -> shared variable
195195
blocked: Boolean (default True)
196196
"""
@@ -235,7 +235,7 @@ def __init__(self, vars, shared, blocked=True):
235235
"""
236236
Parameters
237237
----------
238-
vars: list of sampling variables
238+
vars: list of sampling value variables
239239
shared: dict of Aesara variable -> shared variable
240240
blocked: Boolean (default True)
241241
"""

pymc3/step_methods/elliptical_slice.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import numpy.random as nr
1818

19+
from pymc3.aesaraf import inputvars
1920
from pymc3.model import modelcontext
2021
from pymc3.step_methods.arraystep import ArrayStep, Competence
2122

@@ -61,7 +62,7 @@ class EllipticalSlice(ArrayStep):
6162
Parameters
6263
----------
6364
vars: list
64-
List of variables for sampler.
65+
List of value variables for sampler.
6566
prior_cov: array, optional
6667
Covariance matrix of the multivariate Gaussian prior.
6768
prior_chol: array, optional
@@ -88,6 +89,8 @@ def __init__(self, vars=None, prior_cov=None, prior_chol=None, model=None, **kwa
8889

8990
if vars is None:
9091
vars = self.model.cont_vars
92+
else:
93+
vars = [self.model.rvs_to_values.get(var, var) for var in vars]
9194
vars = inputvars(vars)
9295

9396
super().__init__(vars, [self.model.fastlogp], **kwargs)

pymc3/step_methods/hmc/base_hmc.py

+2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def __init__(
8989

9090
if vars is None:
9191
vars = self._model.cont_vars
92+
else:
93+
vars = [self._model.rvs_to_values.get(var, var) for var in vars]
9294

9395
super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **aesara_kwargs)
9496

pymc3/step_methods/hmc/hmc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
6060
Parameters
6161
----------
6262
vars: list, default=None
63-
List of Aesara variables. If None, all continuous RVs from the
63+
List of value variables. If None, all continuous RVs from the
6464
model are included.
6565
path_length: float, default=2
6666
Total length to travel

pymc3/step_methods/hmc/nuts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs)
115115
Parameters
116116
----------
117117
vars: list, default=None
118-
List of Aesara variables. If None, all continuous RVs from the
118+
List of value variables. If None, all continuous RVs from the
119119
model are included.
120120
Emax: float, default 1000
121121
Maximum energy change allowed during leapfrog steps. Larger

pymc3/step_methods/metropolis.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(
130130
Parameters
131131
----------
132132
vars: list
133-
List of variables for sampler
133+
List of value variables for sampler
134134
S: standard deviation or covariance matrix
135135
Some measure of variance to parameterize proposal distribution
136136
proposal_dist: function
@@ -153,6 +153,8 @@ def __init__(
153153

154154
if vars is None:
155155
vars = model.value_vars
156+
else:
157+
vars = [model.rvs_to_values.get(var, var) for var in vars]
156158
vars = pm.inputvars(vars)
157159

158160
if S is None:
@@ -288,7 +290,7 @@ class BinaryMetropolis(ArrayStep):
288290
Parameters
289291
----------
290292
vars: list
291-
List of variables for sampler
293+
List of value variables for sampler
292294
scaling: scalar or array
293295
Initial scale factor for proposal. Defaults to 1.
294296
tune: bool
@@ -321,6 +323,8 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
321323
self.steps_until_tune = tune_interval
322324
self.accepted = 0
323325

326+
vars = [model.rvs_to_values.get(var, var) for var in vars]
327+
324328
if not all([v.dtype in pm.discrete_types for v in vars]):
325329
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
326330

@@ -388,7 +392,7 @@ class BinaryGibbsMetropolis(ArrayStep):
388392
Parameters
389393
----------
390394
vars: list
391-
List of variables for sampler
395+
List of value variables for sampler
392396
order: list or 'random'
393397
List of integers indicating the Gibbs update order
394398
e.g., [0, 2, 1, ...]. Default is random
@@ -410,6 +414,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
410414
self.transit_p = transit_p
411415

412416
initial_point = model.initial_point
417+
vars = [model.rvs_to_values.get(var, var) for var in vars]
413418
self.dim = sum(initial_point[v.name].size for v in vars)
414419

415420
if order == "random":
@@ -490,6 +495,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
490495

491496
model = pm.modelcontext(model)
492497

498+
vars = [model.rvs_to_values.get(var, var) for var in vars]
493499
vars = pm.inputvars(vars)
494500

495501
initial_point = model.initial_point
@@ -697,6 +703,8 @@ def __init__(
697703

698704
if vars is None:
699705
vars = model.cont_vars
706+
else:
707+
vars = [model.rvs_to_values.get(var, var) for var in vars]
700708
vars = pm.inputvars(vars)
701709

702710
if S is None:
@@ -846,6 +854,8 @@ def __init__(
846854

847855
if vars is None:
848856
vars = model.cont_vars
857+
else:
858+
vars = [model.rvs_to_values.get(var, var) for var in vars]
849859
vars = pm.inputvars(vars)
850860

851861
if S is None:

pymc3/step_methods/mlda.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def __init__(self, *args, **kwargs):
7474
value_vars = kwargs.get("vars", None)
7575
if value_vars is None:
7676
value_vars = model.value_vars
77+
else:
78+
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
7779
value_vars = pm.inputvars(value_vars)
7880
shared = pm.make_shared_replacements(initial_values, value_vars, model)
7981

@@ -142,6 +144,8 @@ def __init__(self, *args, **kwargs):
142144
value_vars = kwargs.get("vars", None)
143145
if value_vars is None:
144146
value_vars = model.value_vars
147+
else:
148+
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
145149
value_vars = pm.inputvars(value_vars)
146150
shared = pm.make_shared_replacements(initial_values, value_vars, model)
147151

@@ -218,7 +222,7 @@ class MLDA(ArrayStepShared):
218222
Note this list excludes the model passed to the model
219223
argument above, which is the finest available.
220224
vars : list
221-
List of variables for sampler
225+
List of value variables for sampler
222226
base_sampler : string
223227
Sampler used in the base (coarsest) chain. Can be 'Metropolis' or
224228
'DEMetropolisZ'. Defaults to 'DEMetropolisZ'.
@@ -549,6 +553,8 @@ def __init__(
549553
# Process model variables
550554
if value_vars is None:
551555
value_vars = model.value_vars
556+
else:
557+
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
552558
value_vars = pm.inputvars(value_vars)
553559
self.vars = value_vars
554560
self.var_names = [var.name for var in self.vars]

pymc3/step_methods/pgbart.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class PGBART(ArrayStepShared):
3939
Parameters
4040
----------
4141
vars: list
42-
List of variables for sampler
42+
List of value variables for sampler
4343
num_particles : int
4444
Number of particles for the conditional SMC sampler. Defaults to 10
4545
max_stages : int

pymc3/step_methods/sgmcmc.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class BaseStochasticGradient(ArrayStepShared):
8787
Parameters
8888
----------
8989
vars: list
90-
List of variables for sampler
90+
List of value variables for sampler
9191
batch_size`: int
9292
Batch Size for each step
9393
total_size: int
@@ -132,6 +132,8 @@ def __init__(
132132

133133
if vars is None:
134134
vars = model.value_vars
135+
else:
136+
vars = [model.rvs_to_values.get(var, var) for var in vars]
135137

136138
vars = inputvars(vars)
137139

pymc3/step_methods/slicer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class Slice(ArrayStep):
3535
Parameters
3636
----------
3737
vars: list
38-
List of variables for sampler.
38+
List of value variables for sampler.
3939
w: float
4040
Initial width of slice (Defaults to 1).
4141
tune: bool
@@ -57,6 +57,8 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, *
5757

5858
if vars is None:
5959
vars = self.model.cont_vars
60+
else:
61+
vars = [self.model.rvs_to_values.get(var, var) for var in vars]
6062
vars = inputvars(vars)
6163

6264
super().__init__(vars, [self.model.fastlogp], **kwargs)

pymc3/tests/test_step.py

+60-2
Original file line numberDiff line numberDiff line change
@@ -618,8 +618,8 @@ def test_step_categorical(self):
618618
check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0))
619619
with model:
620620
steps = (
621-
CategoricalGibbsMetropolis(model.x, proposal="uniform"),
622-
CategoricalGibbsMetropolis(model.x, proposal="proportional"),
621+
CategoricalGibbsMetropolis([model.x], proposal="uniform"),
622+
CategoricalGibbsMetropolis([model.x], proposal="proportional"),
623623
)
624624
for step in steps:
625625
idata = sample(8000, tune=0, step=step, start=start, model=model, random_seed=1)
@@ -1767,3 +1767,61 @@ def perform(self, node, inputs, outputs):
17671767
)
17681768
assert Q_1_0.mean(axis=1) == 0.0
17691769
assert Q_2_1.mean(axis=1) == 0.0
1770+
1771+
1772+
class TestRVsAssignmentSteps:
1773+
"""
1774+
Test that step methods convert input RVs to respective value vars
1775+
Step methods are tested with one and two variables to cover compound
1776+
the special branches in `BlockedStep.__new__`
1777+
"""
1778+
1779+
@pytest.mark.parametrize(
1780+
"step, step_kwargs",
1781+
[
1782+
(NUTS, {}),
1783+
(HamiltonianMC, {}),
1784+
(Metropolis, {}),
1785+
(Slice, {}),
1786+
(EllipticalSlice, {"prior_cov": np.eye(1)}),
1787+
(DEMetropolis, {}),
1788+
(DEMetropolisZ, {}),
1789+
# (MLDA, {}), # TODO
1790+
],
1791+
)
1792+
def test_continuous_steps(self, step, step_kwargs):
1793+
with Model() as m:
1794+
c1 = HalfNormal("c1")
1795+
c2 = HalfNormal("c2")
1796+
1797+
assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
1798+
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
1799+
step([c1, c2], **step_kwargs).vars
1800+
)
1801+
1802+
@pytest.mark.parametrize(
1803+
"step, step_kwargs",
1804+
[
1805+
(BinaryGibbsMetropolis, {}),
1806+
(CategoricalGibbsMetropolis, {}),
1807+
],
1808+
)
1809+
def test_discrete_steps(self, step, step_kwargs):
1810+
with Model() as m:
1811+
d1 = Bernoulli("d1", p=0.5)
1812+
d2 = Bernoulli("d2", p=0.5)
1813+
1814+
assert [m.rvs_to_values[d1]] == step([d1], **step_kwargs).vars
1815+
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(
1816+
step([d1, d2], **step_kwargs).vars
1817+
)
1818+
1819+
def test_compound_step(self):
1820+
with Model() as m:
1821+
c1 = HalfNormal("c1")
1822+
c2 = HalfNormal("c2")
1823+
1824+
step1 = NUTS([c1])
1825+
step2 = NUTS([c2])
1826+
step = CompoundStep([step1, step2])
1827+
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(step.vars)

0 commit comments

Comments
 (0)