forked from mollymr305/mnist-mc-dropout
helpers.py
"""Helper functions."""
import cPickle as pickle
import gzip
import numpy as np
import os
from urllib import urlretrieve
def report(text, output_file):
f = open(output_file, 'a')
f.write('{}\n'.format(text))
f.close
def load_mnist_data():
    """Download MNIST if needed and return train/val/test arrays.

    Images are reshaped to (N, 1, 28, 28) float32; labels are int32.
    """
    mnist_filename = 'mnist.pkl.gz'
    if not os.path.exists(mnist_filename):
        print 'Downloading MNIST data ...'
        url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
        urlretrieve(url, mnist_filename)
    train, val, test = pickle.load(gzip.open(mnist_filename, 'rb'))
    # Training dataset
    X_train, Y_train = train
    X_train = X_train.reshape((-1, 1, 28, 28)).astype('float32')
    Y_train = Y_train.astype('int32')
    # Validation dataset
    X_val, Y_val = val
    X_val = X_val.reshape((-1, 1, 28, 28)).astype('float32')
    Y_val = Y_val.astype('int32')
    # Test dataset
    X_test, Y_test = test
    X_test = X_test.reshape((-1, 1, 28, 28)).astype('float32')
    Y_test = Y_test.astype('int32')
    return X_train, Y_train, X_val, Y_val, X_test, Y_test


def generate_batches(data, target, batch_size=500, stochastic=True):
    """Yield (data, target) mini-batches, shuffling indices first when stochastic."""
    idx = np.arange(len(data))
    if stochastic:
        np.random.shuffle(idx)
    for k in xrange(0, len(data), batch_size):
        sample = idx[k:k + batch_size]
        yield data[sample], target[sample]
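

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file): a minimal sketch of how
# these helpers could be combined in a training script. The log file name
# 'training_log.txt' and the single-batch loop are assumptions for
# illustration only, not the repository's actual training code.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    X_train, Y_train, X_val, Y_val, X_test, Y_test = load_mnist_data()
    report('Loaded {} training examples'.format(len(X_train)), 'training_log.txt')
    for X_batch, Y_batch in generate_batches(X_train, Y_train, batch_size=500):
        # A model's training step would consume each mini-batch here.
        report('First batch has shape {}'.format(X_batch.shape), 'training_log.txt')
        break  # only demonstrate the first batch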