22
22
23
23
import pymc as pm
24
24
25
- from pymc .variational import opvi , test_functions
25
+ from pymc .variational import test_functions
26
26
from pymc .variational .approximations import Empirical , FullRank , MeanField
27
27
from pymc .variational .operators import KL , KSD
28
28
@@ -334,7 +334,7 @@ class ADVI(KLqp):
334
334
The last ones are local random variables
335
335
:math:`{\cal Z}=\{\mathbf{z}_{i}\}_{i=1}^{N}`, where
336
336
:math:`\mathbf{z}_{i}=\{\mathbf{z}_{i}^{k}\}_{k=1}^{V_{l}}`.
337
- These RVs are used only in AEVB.
337
+ These RVs are used only in AEVB (which is not implemented in PyMC) .
338
338
339
339
The goal of ADVI is to approximate the posterior distribution
340
340
:math:`p(\Theta,{\cal Z}|{\cal Y})` by variational posterior
@@ -408,8 +408,8 @@ class ADVI(KLqp):
408
408
409
409
- The probabilistic model
410
410
411
- `model` with three types of RVs (`observed_RVs`,
412
- `global_RVs` and `local_RVs` ).
411
+ `model` with two types of RVs (`observed_RVs`,
412
+ `global_RVs`).
413
413
414
414
- (optional) Minibatches
415
415
@@ -428,10 +428,6 @@ class ADVI(KLqp):
428
428
429
429
Parameters
430
430
----------
431
- local_rv: dict[var->tuple]
432
- mapping {model_variable -> approx params}
433
- Local Vars are used for Autoencoding Variational Bayes
434
- See (AEVB; Kingma and Welling, 2014) for details
435
431
model: :class:`pymc.Model`
436
432
PyMC model for inference
437
433
random_seed: None or int
@@ -463,10 +459,6 @@ class FullRankADVI(KLqp):
463
459
464
460
Parameters
465
461
----------
466
- local_rv: dict[var->tuple]
467
- mapping {model_variable -> approx params}
468
- Local Vars are used for Autoencoding Variational Bayes
469
- See (AEVB; Kingma and Welling, 2014) for details
470
462
model: :class:`pymc.Model`
471
463
PyMC model for inference
472
464
random_seed: None or int
@@ -571,8 +563,6 @@ def __init__(
571
563
kernel = test_functions .rbf ,
572
564
** kwargs ,
573
565
):
574
- if kwargs .get ("local_rv" ) is not None :
575
- raise opvi .AEVBInferenceError ("SVGD does not support local groups" )
576
566
empirical = Empirical (
577
567
size = n_particles ,
578
568
jitter = jitter ,
@@ -639,9 +629,7 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
639
629
"is often **underestimated** when using temperature = 1."
640
630
)
641
631
if approx is None :
642
- approx = FullRank (
643
- model = kwargs .pop ("model" , None ), local_rv = kwargs .pop ("local_rv" , None )
644
- )
632
+ approx = FullRank (model = kwargs .pop ("model" , None ))
645
633
super ().__init__ (estimator = estimator , approx = approx , kernel = kernel , ** kwargs )
646
634
647
635
def fit (
0 commit comments