# sgmcmc.py (forked from pymc-devs/pymc)
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from collections import OrderedDict

import aesara
import aesara.tensor as at

from pymc3.aesaraf import at_rng, make_shared_replacements
from pymc3.model import inputvars, modelcontext
from pymc3.step_methods.arraystep import ArrayStepShared

__all__ = []

EXPERIMENTAL_WARNING = (
    "Warning: Stochastic Gradient based sampling methods are experimental step methods and not yet"
    " recommended for use in PyMC3!"
)


def _value_error(cond, msg):
    """Raise a ValueError with message ``msg`` if ``cond`` is False."""
    if not cond:
        raise ValueError(msg)


def _check_minibatches(minibatch_tensors, minibatches):
    _value_error(isinstance(minibatch_tensors, list), "minibatch_tensors must be a list.")
    _value_error(hasattr(minibatches, "__iter__"), "minibatches must be an iterator.")


def prior_dlogp(vars, model, flat_view):
    """Return the gradient of the log prior with respect to the parameters as a vector of size D x 1."""
    terms = at.concatenate([aesara.grad(var.logpt, var).flatten() for var in vars], axis=0)
    dlogp = aesara.clone_replace(terms, flat_view.replacements, strict=False)

    return dlogp
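
# Note: the concatenated prior gradient has shape (D,), where D is the total
# number of flattened parameter entries in `vars`; cloning through
# `flat_view.replacements` rewrites the expression in terms of the single flat
# input vector that the step method operates on.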


def elemwise_dlogL(vars, model, flat_view):
    """
    Return the Jacobian of the log likelihood for each training datum wrt vars
    as a matrix of size N x D.
    """
    # select one observed random variable
    obs_var = model.observed_RVs[0]
    # tensor of shape (batch_size,)
    logL = obs_var.logp_elemwiset.sum(axis=tuple(range(1, obs_var.logp_elemwiset.ndim)))
    # compute the per-datum gradient terms (used downstream, e.g. for Fisher information)
    terms = []
    for var in vars:
        output, _ = aesara.scan(
            lambda i, logX, v: aesara.grad(logX[i], v).flatten(),
            sequences=[at.arange(logL.shape[0])],
            non_sequences=[logL, var],
        )
        terms.append(output)
    dlogL = aesara.clone_replace(
        at.concatenate(terms, axis=1), flat_view.replacements, strict=False
    )
    return dlogL
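
# In stochastic-gradient MCMC the two quantities above are typically combined
# into an unbiased estimate of the gradient of the log posterior: with a
# minibatch of n points drawn from N training points,
#
#     d log p(q | data)  ~=  dlog_prior(q) + (N / n) * sum_i dlogL_i(q)
#
# Subclasses are expected to assemble this scaling inside `mk_training_fn`
# from `total_size`, `batch_size`, `self.dlog_prior`, and
# `self.dlogp_elemwise`; the exact construction is left to the concrete
# subclass.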


class BaseStochasticGradient(ArrayStepShared):
    R"""
    BaseStochasticGradient Object

    To use a BaseStochasticGradient, supply the probabilistic model
    (:code:`model`) with the data for the observed variables given as a
    `GeneratorOp`, or via ``minibatches`` and ``minibatch_tensors``.

    Parameters
    ----------
    vars: list
        List of value variables for the sampler
    batch_size: int
        Batch size for each step
    total_size: int
        Total size of the training data
    step_size: float
        Step size for the parameter update
    model: PyMC Model
        Optional model for the sampling step. Defaults to None (taken from context)
    random_seed: int
        The seed used to initialize the random stream
    minibatches: iterator
        Must not be None if the observed RV is not a GeneratorOp
    minibatch_tensors: list of tensors
        Must not be None if the observed RV is not a GeneratorOp.
        The length of this list should match the length of ``next(minibatches)``

    Notes
    -----
    Defining a BaseStochasticGradient subclass requires a custom implementation
    of the following methods:

    - :code:`.mk_training_fn()`
        Returns an Aesara function which is called for each sampling step
    - :code:`._initialize_values()`
        Returns None; it creates the class variables required by the training function
    """

    def __init__(
        self,
        vars=None,
        batch_size=None,
        total_size=None,
        step_size=1.0,
        model=None,
        random_seed=None,
        minibatches=None,
        minibatch_tensors=None,
        **kwargs
    ):
        warnings.warn(EXPERIMENTAL_WARNING)
        model = modelcontext(model)

        if vars is None:
            vars = model.value_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]

        vars = inputvars(vars)

        self.model = model
        self.vars = vars
        self.batch_size = batch_size
        self.total_size = total_size
        _value_error(
            total_size is not None and batch_size is not None,
            "total_size and batch_size of training data have to be specified",
        )
        self.expected_iter = int(total_size / batch_size)

        # set random stream
        self.random = None
        if random_seed is None:
            self.random = at_rng()
        else:
            self.random = at_rng(random_seed)

        self.step_size = step_size

        shared = make_shared_replacements(vars, model)

        self.updates = OrderedDict()
        # XXX: This needs to be refactored
        self.q_size = None  # int(sum(v.dsize for v in self.vars))

        # This seems to be the only place that `Model.flatten` is used.
        # TODO: Why not _actually_ flatten the variables?
        # E.g. `flat_vars = at.concatenate([var.ravel() for var in vars])`
        # or `set_subtensor` the `vars` into an `at.vector`?
        flat_view = model.flatten(vars)

        self.inarray = [flat_view.input]

        self.dlog_prior = prior_dlogp(vars, model, flat_view)
        self.dlogp_elemwise = elemwise_dlogL(vars, model, flat_view)

        if minibatch_tensors is not None:
            _check_minibatches(minibatch_tensors, minibatches)
            self.minibatches = minibatches

            # Replace input shared variables with tensors
            def is_shared(t):
                return isinstance(t, aesara.compile.sharedvalue.SharedVariable)

            tensors = [(t.type() if is_shared(t) else t) for t in minibatch_tensors]
            updates = OrderedDict(
                {t: t_ for t, t_ in zip(minibatch_tensors, tensors) if is_shared(t)}
            )
            self.minibatch_tensors = tensors
            self.inarray += self.minibatch_tensors
            self.updates.update(updates)

        self._initialize_values()
        super().__init__(vars, shared)
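
    # After `__init__`, a subclass can rely on: `self.inarray` (the flat
    # parameter input, followed by any minibatch tensors), `self.dlog_prior`
    # and `self.dlogp_elemwise` (the symbolic gradients built above),
    # `self.updates` (shared-variable replacements for minibatch inputs), and
    # `self.random`, `self.step_size`, and `self.expected_iter`.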

    def _initialize_values(self):
        """Initialize the parameters for the stochastic gradient minibatch algorithm."""
        raise NotImplementedError

    def mk_training_fn(self):
        raise NotImplementedError

    def training_complete(self):
        """Return True once `astep` has been called the expected number of times."""
        return self.expected_iter == self.t

    def astep(self, q0):
        """Perform a single update in the stochastic gradient method.

        Returns the new shared values and the values sampled.
        The size and ordering of `q0` and `q` must be the same.

        Parameters
        ----------
        q0: list
            List of shared values and values sampled from the last estimate

        Returns
        -------
        q
        """
        if hasattr(self, "minibatch_tensors"):
            return q0 + self.training_fn(q0, *next(self.minibatches))
        else:
            return q0 + self.training_fn(q0)
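
# Hypothetical usage sketch (not part of this module): a concrete subclass,
# say `SGFS(BaseStochasticGradient)`, that implements `_initialize_values` and
# `mk_training_fn` could be used like any other PyMC3 step method, e.g.
#
#     minibatches = iter(data_batches)  # each element matches minibatch_tensors
#     with model:
#         step = SGFS(
#             batch_size=500,
#             total_size=50_000,
#             step_size=1e-3,
#             minibatches=minibatches,
#             minibatch_tensors=[X_shared, y_shared],
#         )
#         trace = pm.sample(step=step)
#
# The class name `SGFS` and the surrounding variable names are illustrative
# only; the constructor arguments correspond to the parameters documented on
# `BaseStochasticGradient`.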