
Commit 321c57e

Fixed minibatching in NN notebook (#773)

* Fixed minibatching in NN notebook
* Updated text on prediction

1 parent 18cb11d · commit 321c57e

File tree

4 files changed: +120 −6058 lines

.gitignore (+2)

@@ -8,3 +8,5 @@ build
 jupyter_execute
 _thumbnails
 examples/gallery.rst
+
+pixi.lock

examples/variational_inference/bayesian_neural_network_advi.ipynb (+74 −215)

Large diffs are not rendered by default.

examples/variational_inference/bayesian_neural_network_advi.myst.md (+44 −50)
@@ -5,7 +5,7 @@ jupytext:
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: Python 3 (ipykernel)
+  display_name: default
   language: python
   name: python3
 ---
@@ -114,7 +114,7 @@ A neural network is quite simple. The basic unit is a [perceptron](https://en.wi
 jupyter:
   outputs_hidden: true
 ---
-def construct_nn():
+def construct_nn(batch_size=50):
     n_hidden = 5

     # Initialize random weights between each layer
@@ -130,12 +130,13 @@ def construct_nn():
     }

     with pm.Model(coords=coords) as neural_network:
-        # Define minibatch variables
-        minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)

         # Define data variables using minibatches
-        ann_input = pm.Data("ann_input", minibatch_x, mutable=True, dims=("obs_id", "train_cols"))
-        ann_output = pm.Data("ann_output", minibatch_y, mutable=True, dims="obs_id")
+        X_data = pm.Data("X_data", X_train, dims=("obs_id", "train_cols"))
+        Y_data = pm.Data("Y_data", Y_train, dims="obs_id")
+
+        # Define minibatch variables
+        ann_input, ann_output = pm.Minibatch(X_data, Y_data, batch_size=batch_size)

         # Weights from input to hidden layer
         weights_in_1 = pm.Normal(
@@ -161,7 +162,6 @@ def construct_nn():
             act_out,
             observed=ann_output,
             total_size=X_train.shape[0],  # IMPORTANT for minibatches
-            dims="obs_id",
         )
     return neural_network

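For orientation, here is a minimal, self-contained sketch of the pattern the hunks above introduce, using toy data (the array shapes and the simple logistic-regression likelihood are illustrative assumptions, not the notebook's network): full-sized `pm.Data` containers are registered in the model first, `pm.Minibatch` then draws random subsets of them at each gradient step, and `total_size` rescales the minibatch likelihood back to the full dataset.

```python
import numpy as np
import pymc as pm

# Toy stand-ins for the notebook's training arrays (assumed shapes)
rng = np.random.default_rng(42)
X_train = rng.normal(size=(500, 2))
Y_train = rng.integers(0, 2, size=500)

coords = {
    "obs_id": np.arange(X_train.shape[0]),
    "train_cols": np.arange(X_train.shape[1]),
}
with pm.Model(coords=coords) as toy_model:
    # Full-sized data containers live in the model ...
    X_data = pm.Data("X_data", X_train, dims=("obs_id", "train_cols"))
    Y_data = pm.Data("Y_data", Y_train, dims="obs_id")
    # ... and the minibatch views are drawn from them
    ann_input, ann_output = pm.Minibatch(X_data, Y_data, batch_size=50)

    # A deliberately simple likelihood (logistic regression) in place of the NN
    weights = pm.Normal("weights", 0, 1, dims="train_cols")
    p = pm.math.sigmoid(pm.math.dot(ann_input, weights))
    pm.Bernoulli(
        "out",
        p,
        observed=ann_output,
        total_size=X_train.shape[0],  # IMPORTANT for minibatches
    )
```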
@@ -174,12 +174,16 @@ That's not so bad. The `Normal` priors help regularize the weights. Usually we w

 +++

-### Variational Inference: Scaling model complexity
+## Variational Inference: Scaling model complexity

 We could now just run an MCMC sampler like {class}`pymc.NUTS`, which works pretty well in this case, but as already mentioned, this will become very slow as we scale our model up to deeper architectures with more layers.

 Instead, we will use the {class}`pymc.ADVI` variational inference algorithm. This is much faster and will scale better. Note that this is a mean-field approximation, so we ignore correlations in the posterior.

+### Mini-batch ADVI
+
+While this simulated dataset is small enough to fit all at once, it would not scale to something big like ImageNet. In the model above, we have set up minibatches that will allow for scaling to larger data sets. Moreover, training on mini-batches of data (stochastic gradient descent) avoids local minima and can lead to faster convergence.
+
 ```{code-cell} ipython3
 %%time
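As a rough sketch of the fitting step on the toy model from the earlier example (the iteration count here is an arbitrary choice for illustration, not the notebook's): ADVI is run directly on the minibatched model, `approx.hist` holds the optimization history that the notebook plots as the ELBO curve, and `approx.sample` draws from the fitted mean-field approximation, just as the notebook does with `trace = approx.sample(draws=5000)`.

```python
import matplotlib.pyplot as plt

with toy_model:  # the sketch model defined above
    approx = pm.fit(n=20000, method=pm.ADVI())

# Convergence check: the recorded objective history per iteration
plt.plot(approx.hist)
plt.ylabel("ELBO")
plt.xlabel("iteration")

# Draw posterior samples from the fitted approximation
trace = approx.sample(draws=5000)
```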
@@ -199,17 +203,38 @@ plt.xlabel("iteration");
 trace = approx.sample(draws=5000)
 ```

-Now that we trained our model, lets predict on the hold-out set using a posterior predictive check (PPC). We can use {func}`~pymc.sample_posterior_predictive` to generate new data (in this case class predictions) from the posterior (sampled from the variational estimation).
+Now that we have trained our model, let's predict on the hold-out set using a posterior predictive check (PPC). We can use {func}`pymc.sample_posterior_predictive` to generate new data (in this case class predictions) from the posterior (sampled from the variational estimation).
+
+To predict on the entire test set (and not just the minibatches) we need to create a new model object that removes the minibatches. Notice that we are using our fitted `trace` to sample from the posterior predictive distribution, using the posterior estimates from the original model. There is no new inference here; we are just using the same model and the same posterior estimates to generate predictions. The {class}`Flat` distribution is just a placeholder to make the model work; the actual values are sampled from the posterior.

 ```{code-cell} ipython3
----
-jupyter:
-  outputs_hidden: true
----
-with neural_network:
-    pm.set_data(new_data={"ann_input": X_test})
-    ppc = pm.sample_posterior_predictive(trace)
-    trace.extend(ppc)
+def sample_posterior_predictive(X_test, Y_test, trace, n_hidden=5):
+    coords = {
+        "hidden_layer_1": np.arange(n_hidden),
+        "hidden_layer_2": np.arange(n_hidden),
+        "train_cols": np.arange(X_test.shape[1]),
+        "obs_id": np.arange(X_test.shape[0]),
+    }
+    with pm.Model(coords=coords):
+
+        ann_input = X_test
+        ann_output = Y_test
+
+        weights_in_1 = pm.Flat("w_in_1", dims=("train_cols", "hidden_layer_1"))
+        weights_1_2 = pm.Flat("w_1_2", dims=("hidden_layer_1", "hidden_layer_2"))
+        weights_2_out = pm.Flat("w_2_out", dims="hidden_layer_2")
+
+        # Build neural-network using tanh activation function
+        act_1 = pm.math.tanh(pm.math.dot(ann_input, weights_in_1))
+        act_2 = pm.math.tanh(pm.math.dot(act_1, weights_1_2))
+        act_out = pm.math.sigmoid(pm.math.dot(act_2, weights_2_out))
+
+        # Binary classification -> Bernoulli likelihood
+        out = pm.Bernoulli("out", act_out, observed=ann_output)
+        return pm.sample_posterior_predictive(trace)
+
+
+ppc = sample_posterior_predictive(X_test, Y_test, trace)
 ```

 We can average the predictions for each observation to estimate the underlying probability of class 1.
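A brief sketch of that averaging step, assuming the `ppc` and `Y_test` objects from the cells above (the 0.5 threshold and the accuracy printout are illustrative additions): the sampled Bernoulli outcomes are averaged over chains and draws to estimate P(class = 1) per observation, then thresholded into hard predictions.

```python
# Posterior predictive mean per observation -> estimated probability of class 1
p_class1 = ppc.posterior_predictive["out"].mean(("chain", "draw"))

# Hard predictions and a simple hold-out accuracy
pred = p_class1 > 0.5
print(f"Accuracy = {(Y_test == pred.values).mean() * 100:.1f}%")
```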
@@ -250,18 +275,7 @@ dummy_out = np.ones(grid_2d.shape[0], dtype=np.int8)
 ```

 ```{code-cell} ipython3
----
-jupyter:
-  outputs_hidden: true
----
-coords_eval = {
-    "train_cols": np.arange(grid_2d.shape[1]),
-    "obs_id": np.arange(grid_2d.shape[0]),
-}
-
-with neural_network:
-    pm.set_data(new_data={"ann_input": grid_2d, "ann_output": dummy_out}, coords=coords_eval)
-    ppc = pm.sample_posterior_predictive(trace)
+ppc = sample_posterior_predictive(grid_2d, dummy_out, trace)
 ```

 ```{code-cell} ipython3
@@ -304,27 +318,6 @@ We can see that very close to the decision boundary, our uncertainty as to which

 +++

-## Mini-batch ADVI
-
-So far, we have trained our model on all data at once. Obviously this won't scale to something like ImageNet. Moreover, training on mini-batches of data (stochastic gradient descent) avoids local minima and can lead to faster convergence.
-
-Fortunately, ADVI can be run on mini-batches as well. It just requires some setting up:
-
-```{code-cell} ipython3
-minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)
-neural_network_minibatch = construct_nn(minibatch_x, minibatch_y)
-with neural_network_minibatch:
-    approx = pm.fit(40000, method=pm.ADVI())
-```
-
-```{code-cell} ipython3
-plt.plot(approx.hist)
-plt.ylabel("ELBO")
-plt.xlabel("iteration");
-```
-
-As you can see, mini-batch ADVI's running time is much lower. It also seems to converge faster.
-
 For fun, we can also look at the trace. The point is that we also get uncertainty of our Neural Network weights.

 ```{code-cell} ipython3
@@ -352,6 +345,7 @@ You might argue that the above network isn't really deep, but note that we could
 - This notebook was originally authored as a [blog post](https://twiecki.github.io/blog/2016/06/01/bayesian-deep-learning/) by Thomas Wiecki in 2016
 - Updated by Chris Fonnesbeck for PyMC v4 in 2022
 - Updated by Oriol Abril-Pla and Earl Bellinger in 2023
+- Updated by Chris Fonnesbeck in 2024

 ## Watermark
