Skip to content

Commit c492f64

Browse files
ferrinetwiecki
authored andcommitted
remove references to local_rvs
1 parent 367bdd6 commit c492f64

File tree

3 files changed

+5
-30
lines changed

3 files changed

+5
-30
lines changed

pymc/variational/approximations.py

-10
Original file line numberDiff line numberDiff line change
@@ -313,15 +313,7 @@ class SingleGroupApproximation(Approximation):
313313
_group_class = None
314314

315315
def __init__(self, *args, **kwargs):
316-
local_rv = kwargs.get("local_rv")
317316
groups = [self._group_class(None, *args, **kwargs)]
318-
if local_rv is not None:
319-
groups.extend(
320-
[
321-
Group([v], params=p, local=True, model=kwargs.get("model"))
322-
for v, p in local_rv.items()
323-
]
324-
)
325317
super().__init__(groups, model=kwargs.get("model"))
326318

327319
def __getattr__(self, item):
@@ -360,8 +352,6 @@ class Empirical(SingleGroupApproximation):
360352
_group_class = EmpiricalGroup
361353

362354
def __init__(self, trace=None, size=None, **kwargs):
363-
if kwargs.get("local_rv", None) is not None:
364-
raise opvi.LocalGroupError("Empirical approximation does not support local variables")
365355
super().__init__(trace=trace, size=size, **kwargs)
366356

367357
def evaluate_over_trace(self, node):

pymc/variational/inference.py

+5-17
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import pymc as pm
2424

25-
from pymc.variational import opvi, test_functions
25+
from pymc.variational import test_functions
2626
from pymc.variational.approximations import Empirical, FullRank, MeanField
2727
from pymc.variational.operators import KL, KSD
2828

@@ -334,7 +334,7 @@ class ADVI(KLqp):
334334
The last ones are local random variables
335335
:math:`{\cal Z}=\{\mathbf{z}_{i}\}_{i=1}^{N}`, where
336336
:math:`\mathbf{z}_{i}=\{\mathbf{z}_{i}^{k}\}_{k=1}^{V_{l}}`.
337-
These RVs are used only in AEVB.
337+
These RVs are used only in AEVB (which is not implemented in PyMC).
338338
339339
The goal of ADVI is to approximate the posterior distribution
340340
:math:`p(\Theta,{\cal Z}|{\cal Y})` by variational posterior
@@ -408,8 +408,8 @@ class ADVI(KLqp):
408408
409409
- The probabilistic model
410410
411-
`model` with three types of RVs (`observed_RVs`,
412-
`global_RVs` and `local_RVs`).
411+
`model` with two types of RVs (`observed_RVs`,
412+
`global_RVs`).
413413
414414
- (optional) Minibatches
415415
@@ -428,10 +428,6 @@ class ADVI(KLqp):
428428
429429
Parameters
430430
----------
431-
local_rv: dict[var->tuple]
432-
mapping {model_variable -> approx params}
433-
Local Vars are used for Autoencoding Variational Bayes
434-
See (AEVB; Kingma and Welling, 2014) for details
435431
model: :class:`pymc.Model`
436432
PyMC model for inference
437433
random_seed: None or int
@@ -463,10 +459,6 @@ class FullRankADVI(KLqp):
463459
464460
Parameters
465461
----------
466-
local_rv: dict[var->tuple]
467-
mapping {model_variable -> approx params}
468-
Local Vars are used for Autoencoding Variational Bayes
469-
See (AEVB; Kingma and Welling, 2014) for details
470462
model: :class:`pymc.Model`
471463
PyMC model for inference
472464
random_seed: None or int
@@ -571,8 +563,6 @@ def __init__(
571563
kernel=test_functions.rbf,
572564
**kwargs,
573565
):
574-
if kwargs.get("local_rv") is not None:
575-
raise opvi.AEVBInferenceError("SVGD does not support local groups")
576566
empirical = Empirical(
577567
size=n_particles,
578568
jitter=jitter,
@@ -639,9 +629,7 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
639629
"is often **underestimated** when using temperature = 1."
640630
)
641631
if approx is None:
642-
approx = FullRank(
643-
model=kwargs.pop("model", None), local_rv=kwargs.pop("local_rv", None)
644-
)
632+
approx = FullRank(model=kwargs.pop("model", None))
645633
super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)
646634

647635
def fit(

pymc/variational/opvi.py

-3
Original file line numberDiff line numberDiff line change
@@ -1145,9 +1145,6 @@ class Approximation(WithMemoization):
11451145
- :class:`FullRank`
11461146
- :class:`Empirical`
11471147
1148-
Single group accepts `local_rv` keyword with dict mapping PyMC variables
1149-
to their local Group parameters dict
1150-
11511148
See Also
11521149
--------
11531150
:class:`Group`

0 commit comments

Comments
 (0)