Skip to content

Commit c3afa00

Browse files
authored
Vi summary (#2230)
* add draft * add draft * move and rename * better docstring * added tests
1 parent b505f47 commit c3afa00

File tree

3 files changed

+93
-5
lines changed

3 files changed

+93
-5
lines changed

pymc3/tests/test_variational_inference.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def test_fit(method, kwargs, error):
395395
'ord',
396396
[1, 2, np.inf]
397397
)
398-
def test_callbacks(diff, ord):
398+
def test_callbacks_convergence(diff, ord):
399399
cb = pm.variational.callbacks.CheckParametersConvergence(every=1, diff=diff, ord=ord)
400400

401401
class _approx:
@@ -406,3 +406,27 @@ class _approx:
406406
with pytest.raises(StopIteration):
407407
cb(approx, None, 1)
408408
cb(approx, None, 10)
409+
410+
411+
def test_tracker_callback():
412+
import time
413+
tracker = pm.callbacks.Tracker(
414+
ints=lambda *t: t[-1],
415+
ints2=lambda ap, h, j: j,
416+
time=time.time,
417+
)
418+
for i in range(10):
419+
tracker(None, None, i)
420+
assert 'time' in tracker.hist
421+
assert 'ints' in tracker.hist
422+
assert 'ints2' in tracker.hist
423+
assert (len(tracker['ints'])
424+
== len(tracker['ints2'])
425+
== len(tracker['time'])
426+
== 10)
427+
assert tracker['ints'] == tracker['ints2'] == list(range(10))
428+
tracker = pm.callbacks.Tracker(
429+
bad=lambda t: t # bad signature
430+
)
431+
with pytest.raises(TypeError):
432+
tracker(None, None, 1)

pymc3/variational/callbacks.py

+66-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import collections
2+
13
import numpy as np
24

35
__all__ = [
46
'Callback',
5-
'CheckParametersConvergence'
7+
'CheckParametersConvergence',
8+
'Tracker'
69
]
710

811

@@ -76,3 +79,65 @@ def __call__(self, approx, _, i):
7679
@staticmethod
7780
def flatten_shared(shared_list):
7881
return np.concatenate([sh.get_value().flatten() for sh in shared_list])
82+
83+
84+
class Tracker(Callback):
85+
"""
86+
Helper class to record arbitrary stats during VI
87+
88+
It is possible to pass a function that takes no arguments
89+
If call fails then (approx, hist, i) are passed
90+
91+
92+
Parameters
93+
----------
94+
kwargs : key word arguments
95+
keys mapping statname to callable that records the stat
96+
97+
Examples
98+
--------
99+
Consider we want time on each iteration
100+
>>> import time
101+
>>> tracker = Tracker(time=time.time)
102+
>>> with model:
103+
... approx = pm.fit(callbacks=[tracker])
104+
105+
Time can be accessed via :code:`tracker['time']` now
106+
For more complex summary one can use callable that takes
107+
(approx, hist, i) as arguments
108+
>>> with model:
109+
... my_callable = lambda ap, h, i: h[-1]
110+
... tracker = Tracker(some_stat=my_callable)
111+
... approx = pm.fit(callbacks=[tracker])
112+
113+
Multiple stats are valid too
114+
>>> with model:
115+
... tracker = Tracker(some_stat=my_callable, time=time.time)
116+
... approx = pm.fit(callbacks=[tracker])
117+
"""
118+
def __init__(self, **kwargs):
119+
self.whatchdict = kwargs
120+
self.hist = collections.defaultdict(list)
121+
122+
def record(self, approx, hist, i):
123+
for key, fn in self.whatchdict.items():
124+
try:
125+
res = fn()
126+
# if `*t` argument is used
127+
# fail will be somehow detected.
128+
# We want both calls to be tried.
129+
# Upper one has more priority as
130+
# arbitrary functions can have some
131+
# defaults in positionals. Bad idea
132+
# to try fn(approx, hist, i) first
133+
except Exception:
134+
res = fn(approx, hist, i)
135+
self.hist[key].append(res)
136+
137+
def clear(self):
138+
self.hist = collections.defaultdict(list)
139+
140+
def __getitem__(self, item):
141+
return self.hist[item]
142+
143+
__call__ = record

pymc3/variational/inference.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22

33
import logging
44
import warnings
5-
import tqdm
65

76
import numpy as np
7+
import tqdm
88

99
import pymc3 as pm
10+
from pymc3.variational import test_functions
1011
from pymc3.variational.approximations import MeanField, FullRank, Empirical
1112
from pymc3.variational.operators import KL, KSD, AKSD
1213
from pymc3.variational.opvi import Approximation
13-
from pymc3.variational import test_functions
14-
1514

1615
logger = logging.getLogger(__name__)
1716

0 commit comments

Comments
 (0)