
Commit 84db639

add total gradient norm to VI (#2257)
* add total gradient norm to VI
* fix printing problems when loss is big
* fix lint
* add temperature
* Unused Import
* fix test
1 parent 98a2e03 commit 84db639

5 files changed: +91 -38 lines changed
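The net effect for users: the hard-coded `total_norm_constraint(grad, 10)` inside the KSD operator is removed, gradient clipping is instead exposed as a `total_grad_norm_constraint` keyword on the update/step machinery, and SVGD/ASVGD gain a `temperature` parameter. A minimal usage sketch, assuming the keyword is forwarded from `pm.fit` through `Inference.fit` to the step function (the path the updated test parametrization exercises); the model and data below are hypothetical:

    import numpy as np
    import pymc3 as pm

    data = np.random.randn(100)  # hypothetical observed data
    with pm.Model() as model:
        mu = pm.Normal('mu', mu=0, sd=1)
        pm.Normal('x', mu=mu, sd=1, observed=data)
        # clip the joint norm of all objective gradients at 10
        approx = pm.fit(10000, method='advi', total_grad_norm_constraint=10)
        # the SVGD family accepts the same keyword through the same code path
        approx_svgd = pm.fit(1000, method='svgd', total_grad_norm_constraint=10)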

pymc3/tests/test_variational_inference.py

Lines changed: 6 additions & 8 deletions
@@ -70,7 +70,7 @@ class TestApproximates:
     class Base(SeededTest):
         inference = None
         NITER = 12000
-        optimizer = pm.adagrad_window(learning_rate=0.01)
+        optimizer = pm.adagrad_window(learning_rate=0.01, n_win=50)
         conv_cb = property(lambda self: [
             pm.callbacks.CheckParametersConvergence(
                 every=500,
@@ -152,7 +152,6 @@ def test_optimizer_with_full_data(self):
             mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
             Normal('x', mu=mu_, sd=sd, observed=data)
             inf = self.inference(start={})
-            inf.fit(10)
             approx = inf.fit(self.NITER,
                              obj_optimizer=self.optimizer,
                              callbacks=self.conv_cb,)
@@ -295,11 +294,9 @@ class TestSVGD(TestApproximates.Base):


 class TestASVGD(TestApproximates.Base):
-    NITER = 15000
-    inference = ASVGD
+    NITER = 5000
+    inference = functools.partial(ASVGD, temperature=1.5)
     test_aevb = _test_aevb
-    optimizer = pm.adagrad_window(learning_rate=0.002)
-    conv_cb = []


 class TestEmpirical(SeededTest):
@@ -366,12 +363,13 @@ def test_init_from_noize(self):
         (_advi, dict(start={}), None),
         (_fullrank_advi, dict(), None),
         (_svgd, dict(), None),
-        ('advi', dict(), None),
+        ('advi', dict(total_grad_norm_constraint=10), None),
         ('advi->fullrank_advi', dict(frac=.1), None),
         ('advi->fullrank_advi', dict(frac=1), ValueError),
         ('fullrank_advi', dict(), None),
-        ('svgd', dict(), None),
+        ('svgd', dict(total_grad_norm_constraint=10), None),
         ('svgd', dict(start={}), None),
+        ('asvgd', dict(start={}, total_grad_norm_constraint=10), None),
         ('svgd', dict(local_rv={_model.free_RVs[0]: (0, 1)}), ValueError)
     ]
 )

pymc3/variational/inference.py

Lines changed: 27 additions & 6 deletions
@@ -41,11 +41,15 @@ class Inference(object):
         See (AEVB; Kingma and Welling, 2014) for details
     model : Model
         PyMC3 Model
+    op_kwargs : dict
+        kwargs passed to :class:`Operator`
     kwargs : kwargs
         additional kwargs for :class:`Approximation`
     """

-    def __init__(self, op, approx, tf, local_rv=None, model=None, **kwargs):
+    def __init__(self, op, approx, tf, local_rv=None, model=None, op_kwargs=None, **kwargs):
+        if op_kwargs is None:
+            op_kwargs = dict()
         self.hist = np.asarray(())
         if isinstance(approx, type) and issubclass(approx, Approximation):
             approx = approx(
@@ -56,7 +60,7 @@ def __init__(self, op, approx, tf, local_rv=None, model=None, **kwargs):
         else:  # pragma: no cover
             raise TypeError(
                 'approx should be Approximation instance or Approximation subclass')
-        self.objective = op(approx)(tf)
+        self.objective = op(approx, **op_kwargs)(tf)

     approx = property(lambda self: self.objective.approx)

@@ -146,7 +150,11 @@ def _iterate_without_loss(self, _, step_func, progress, callbacks):
     def _iterate_with_loss(self, n, step_func, progress, callbacks):
         def _infmean(input_array):
             """Return the mean of the finite values of the array"""
-            return np.mean(np.asarray(input_array)[np.isfinite(input_array)])
+            input_array = input_array[np.isfinite(input_array)].astype('float64')
+            if len(input_array) == 0:
+                return np.nan
+            else:
+                return np.mean(input_array)
         scores = np.empty(n)
         scores[:] = np.nan
         i = 0
@@ -531,6 +539,8 @@ class SVGD(Inference):
         PyMC3 model for inference
     kernel : `callable`
         kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
+    temperature : float
+        parameter responsible for exploration, higher temperature gives more broad posterior estimate
     scale_cost_to_minibatch : bool, default False
         Scale cost to minibatch instead of full dataset
     start : `dict`
@@ -548,10 +558,14 @@
     - Qiang Liu, Dilin Wang (2016)
       Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm
       arXiv:1608.04471
+
+    - Yang Liu, Prajit Ramachandran, Qiang Liu, Jian Peng (2017)
+      Stein Variational Policy Gradient
+      arXiv:1704.02399
     """

     def __init__(self, n_particles=100, jitter=.01, model=None, kernel=test_functions.rbf,
-                 scale_cost_to_minibatch=False, start=None, histogram=None,
+                 temperature=1, scale_cost_to_minibatch=False, start=None, histogram=None,
                  random_seed=None, local_rv=None):
         if histogram is None:
             histogram = Empirical.from_noise(
@@ -593,6 +607,8 @@ class ASVGD(Inference):
         See (AEVB; Kingma and Welling, 2014) for details
     kernel : `callable`
         kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
+    temperature : float
+        parameter responsible for exploration, higher temperature gives more broad posterior estimate
     model : :class:`Model`
     kwargs : kwargs for :class:`Approximation`

@@ -604,17 +620,22 @@

     - Dilin Wang, Qiang Liu (2016)
       Learning to Draw Samples: With Application to Amortized MLE for Generative Adversarial Learning
-      https://arxiv.org/abs/1611.01722
+      arXiv:1611.01722
+
+    - Yang Liu, Prajit Ramachandran, Qiang Liu, Jian Peng (2017)
+      Stein Variational Policy Gradient
+      arXiv:1704.02399
     """

     def __init__(self, approx=FullRank, local_rv=None,
-                 kernel=test_functions.rbf, model=None, **kwargs):
+                 kernel=test_functions.rbf, temperature=1, model=None, **kwargs):
         super(ASVGD, self).__init__(
             op=AKSD,
             approx=approx,
             local_rv=local_rv,
             tf=kernel,
             model=model,
+            op_kwargs=dict(temperature=temperature),
             **kwargs
         )
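For the class-based API, `temperature` is threaded through the new `op_kwargs` argument down to the `AKSD` operator, mirroring the test suite's `functools.partial(ASVGD, temperature=1.5)`. A sketch of direct use, assuming `Inference.fit` forwards extra keyword arguments to the objective's `step_function`, and reusing the hypothetical model from the sketch above:

    from pymc3.variational.inference import ASVGD

    with model:
        inference = ASVGD(temperature=1.5)  # triggers the experimental-operator warning from AKSD
        approx = inference.fit(5000, total_grad_norm_constraint=10)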

pymc3/variational/operators.py

Lines changed: 17 additions & 4 deletions
@@ -1,7 +1,7 @@
+import warnings
 from theano import theano, tensor as tt
 from pymc3.variational.opvi import Operator, ObjectiveFunction, _warn_not_used
 from pymc3.variational.stein import Stein
-from pymc3.variational import updates
 import pymc3 as pm

 __all__ = [
@@ -63,7 +63,6 @@ def __call__(self, z, **kwargs):
         grad *= pm.floatX(-1)
         grad = theano.clone(grad, {op.input_matrix: z})
         grad = tt.grad(None, params, known_grads={z: grad})
-        grad = updates.total_norm_constraint(grad, 10)
         return grad


@@ -97,15 +96,29 @@ class KSD(Operator):
     SUPPORT_AEVB = False
     OBJECTIVE = KSDObjective

-    def __init__(self, approx):
+    def __init__(self, approx, temperature=1):
         Operator.__init__(self, approx)
+        self.temperature = temperature
         self.input_matrix = tt.matrix('KSD input matrix')

     def apply(self, f):
         # f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
-        stein = Stein(self.approx, f, self.input_matrix)
+        stein = Stein(
+            approx=self.approx,
+            kernel=f,
+            input_matrix=self.input_matrix,
+            temperature=self.temperature)
         return pm.floatX(-1) * stein.grad


 class AKSD(KSD):
+    def __init__(self, approx, temperature=1):
+        warnings.warn('You are using experimental inference Operator. '
+                      'It requires careful choice of temperature, default is 1. '
+                      'Default temperature works well for low dimensional problems and '
+                      'for significant `n_obj_mc`. Temperature > 1 gives more exploration '
+                      'power to algorithm, < 1 leads to undesirable results. Please take '
+                      'it in account when looking at inference result. Posterior variance '
+                      'is often **underestimated** when using temperature = 1.', stacklevel=2)
+        super(AKSD, self).__init__(approx, temperature)
     SUPPORT_AEVB = True

pymc3/variational/opvi.py

Lines changed: 35 additions & 16 deletions
@@ -98,7 +98,8 @@ def random(self, size=None):
         return self.op.approx.random(size)

     def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
-                more_obj_params=None, more_tf_params=None, more_updates=None, more_replacements=None):
+                more_obj_params=None, more_tf_params=None, more_updates=None,
+                more_replacements=None, total_grad_norm_constraint=None):
        """Calculates gradients for objective function, test function and then
        constructs updates for optimization step

@@ -120,27 +121,24 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, tes
             Add custom updates to resulting updates
         more_replacements : `dict`
             Apply custom replacements before calculating gradients
+        total_grad_norm_constraint : `float`
+            Bounds gradient norm, prevents exploding gradient problem

         Returns
         -------
         :class:`ObjectiveUpdates`
         """
-        if more_obj_params is None:
-            more_obj_params = []
-        if more_tf_params is None:
-            more_tf_params = []
         if more_updates is None:
             more_updates = dict()
-        if more_replacements is None:
-            more_replacements = dict()
         resulting_updates = ObjectiveUpdates()
         if self.test_params:
             self.add_test_updates(
                 resulting_updates,
                 tf_n_mc=tf_n_mc,
                 test_optimizer=test_optimizer,
                 more_tf_params=more_tf_params,
-                more_replacements=more_replacements
+                more_replacements=more_replacements,
+                total_grad_norm_constraint=total_grad_norm_constraint
             )
         else:
             if tf_n_mc is not None:
@@ -152,30 +150,47 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, tes
                 obj_n_mc=obj_n_mc,
                 obj_optimizer=obj_optimizer,
                 more_obj_params=more_obj_params,
-                more_replacements=more_replacements
+                more_replacements=more_replacements,
+                total_grad_norm_constraint=total_grad_norm_constraint
             )
         resulting_updates.update(more_updates)
         return resulting_updates

     def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adagrad_window,
-                         more_tf_params=None, more_replacements=None):
+                         more_tf_params=None, more_replacements=None,
+                         total_grad_norm_constraint=None):
+        if more_tf_params is None:
+            more_tf_params = []
+        if more_replacements is None:
+            more_replacements = dict()
         tf_z = self.get_input(tf_n_mc)
         tf_target = self(tf_z, more_tf_params=more_tf_params)
         tf_target = theano.clone(tf_target, more_replacements, strict=False)
+        grads = pm.updates.get_or_compute_grads(tf_target, self.obj_params + more_tf_params)
+        if total_grad_norm_constraint is not None:
+            grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
         updates.update(
             test_optimizer(
-                tf_target,
+                grads,
                 self.test_params +
                 more_tf_params))

     def add_obj_updates(self, updates, obj_n_mc=None, obj_optimizer=adagrad_window,
-                        more_obj_params=None, more_replacements=None):
+                        more_obj_params=None, more_replacements=None,
+                        total_grad_norm_constraint=None):
+        if more_obj_params is None:
+            more_obj_params = []
+        if more_replacements is None:
+            more_replacements = dict()
         obj_z = self.get_input(obj_n_mc)
         obj_target = self(obj_z, more_obj_params=more_obj_params)
         obj_target = theano.clone(obj_target, more_replacements, strict=False)
+        grads = pm.updates.get_or_compute_grads(obj_target, self.obj_params + more_obj_params)
+        if total_grad_norm_constraint is not None:
+            grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
         updates.update(
             obj_optimizer(
-                obj_target,
+                grads,
                 self.obj_params +
                 more_obj_params))
         if self.op.RETURNS_LOSS:
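`pm.total_norm_constraint` rescales the whole list of gradients jointly when their combined L2 norm exceeds the bound, rather than clipping each tensor on its own. A NumPy illustration of that rescaling rule (an assumption-level sketch of the arithmetic, not the Theano implementation used above):

    import numpy as np

    def clip_total_norm(grads, max_norm, epsilon=1e-7):
        # joint L2 norm across *all* gradient tensors
        total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
        # only shrink the gradients when the joint norm exceeds the bound
        target_norm = min(total_norm, max_norm)
        multiplier = target_norm / (total_norm + epsilon)
        return [g * multiplier for g in grads]

    grads = [np.array([3.0, 4.0]), np.array([12.0])]  # joint norm is 13
    clipped = clip_total_norm(grads, max_norm=10)     # rescaled so the joint norm is ~10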
@@ -189,8 +204,9 @@ def get_input(self, n_mc):
     def step_function(self, obj_n_mc=None, tf_n_mc=None,
                       obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
                       more_obj_params=None, more_tf_params=None,
-                      more_updates=None, more_replacements=None, score=False,
-                      fn_kwargs=None):
+                      more_updates=None, more_replacements=None,
+                      total_grad_norm_constraint=None,
+                      score=False, fn_kwargs=None):
         R"""Step function that should be called on each optimization step.

         Generally it solves the following problem:
@@ -215,6 +231,8 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
             Add custom params for test function optimizer
         more_updates : `dict`
             Add custom updates to resulting updates
+        total_grad_norm_constraint : `float`
+            Bounds gradient norm, prevents exploding gradient problem
         score : `bool`
             calculate loss on each step? Defaults to False for speed
         fn_kwargs : `dict`
@@ -236,7 +254,8 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
             more_obj_params=more_obj_params,
             more_tf_params=more_tf_params,
             more_updates=more_updates,
-            more_replacements=more_replacements)
+            more_replacements=more_replacements,
+            total_grad_norm_constraint=total_grad_norm_constraint)
         if score:
             step_fn = theano.function(
                 [], updates.loss, updates=updates, **fn_kwargs)

pymc3/variational/stein.py

Lines changed: 6 additions & 4 deletions
@@ -1,15 +1,16 @@
 from theano import theano, tensor as tt
 from pymc3.variational.test_functions import rbf
-from pymc3.theanof import memoize
+from pymc3.theanof import memoize, floatX

 __all__ = [
     'Stein'
 ]


 class Stein(object):
-    def __init__(self, approx, kernel=rbf, input_matrix=None):
+    def __init__(self, approx, kernel=rbf, input_matrix=None, temperature=1):
         self.approx = approx
+        self.temperature = floatX(temperature)
         self._kernel_f = kernel
         if input_matrix is None:
             input_matrix = tt.matrix('stein_input_matrix')
@@ -22,8 +23,9 @@ def grad(self):
         t = self.approx.normalizing_constant
         Kxy, dxkxy = self.Kxy, self.dxkxy
         dlogpdx = self.dlogp  # Normalized
-        n = self.input_matrix.shape[0].astype('float32')
-        svgd_grad = (tt.dot(Kxy, dlogpdx) + dxkxy/t) / n
+        n = floatX(self.input_matrix.shape[0])
+        temperature = self.temperature
+        svgd_grad = (tt.dot(Kxy, dlogpdx)/temperature + dxkxy/t) / n
         return svgd_grad

     @property
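Only the driving term `K · dlogp` is divided by the temperature, so raising the temperature gives the repulsive kernel term relatively more weight and spreads the particles out (the "more broad posterior estimate" mentioned in the docstrings). A NumPy transcription of the update direction above, with `t` standing in for `approx.normalizing_constant`:

    import numpy as np

    def svgd_direction(Kxy, dxkxy, dlogp, temperature=1.0, t=1.0):
        # Kxy: (n, n) kernel matrix; dxkxy: (n, d) summed kernel gradients;
        # dlogp: (n, d) score of the target evaluated at each particle
        n = Kxy.shape[0]
        return (Kxy.dot(dlogp) / temperature + dxkxy / t) / n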
