@@ -3,110 +3,104 @@
 
 import time
 
+import numpy as np
 import tensorflow as tf
 
 import tensorlayer as tl
+from tensorlayer.layers import (BatchNorm, BinaryConv2d, BinaryDense, Flatten, Input, MaxPool2d, Sign)
+from tensorlayer.models import Model
 
-tf.logging.set_verbosity(tf.logging.DEBUG)
 tl.logging.set_verbosity(tl.logging.DEBUG)
 
 X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
-# X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
-
-sess = tf.InteractiveSession()
 
 batch_size = 128
 
-x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1])
-y_ = tf.placeholder(tf.int64, shape=[batch_size])
-
 
-def model(x, is_train=True, reuse=False):
+def model(inputs_shape, n_class=10):
     # In BNN, all the layers inputs are binary, with the exception of the first layer.
     # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py
-    with tf.variable_scope("binarynet", reuse=reuse):
-        net = tl.layers.InputLayer(x, name='input')
-        net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')
-        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1')
-
-        net = tl.layers.SignLayer(net)
-        net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')
-        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2')
-
-        net = tl.layers.FlattenLayer(net)
-        # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1')
-        net = tl.layers.SignLayer(net)
-        net = tl.layers.BinaryDenseLayer(net, 256, b_init=None, name='dense')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn3')
-
-        # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2')
-        net = tl.layers.SignLayer(net)
-        net = tl.layers.BinaryDenseLayer(net, 10, b_init=None, name='bout')
-        net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bno')
+    net_in = Input(inputs_shape, name='input')
+    net = BinaryConv2d(32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')(net_in)
+    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn1')(net)
+
+    net = Sign("sign1")(net)
+    net = BinaryConv2d(64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')(net)
+    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn2')(net)
+
+    net = Flatten('ft')(net)
+    net = Sign("sign2")(net)
+    net = BinaryDense(256, b_init=None, name='dense')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn3')(net)
+
+    net = Sign("sign3")(net)
+    net = BinaryDense(10, b_init=None, name='bout')(net)
+    net = BatchNorm(name='bno')(net)
+    net = Model(inputs=net_in, outputs=net, name='binarynet')
     return net
 
 
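Note (not part of the diff): a minimal usage sketch of the model defined above, assuming the TensorLayer 2.x eager API used in the rest of this file. The dummy batch and its size are illustrative only.

    net = model([None, 28, 28, 1])                      # None = variable batch size
    net.eval()                                          # inference mode for BatchNorm
    dummy = np.zeros((8, 28, 28, 1), dtype=np.float32)  # fake batch of 8 images
    logits = net(dummy)                                 # shape (8, 10), unnormalised logits
    preds = tf.argmax(logits, axis=1)                   # predicted class per image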
-# define inferences
-net_train = model(x, is_train=True, reuse=False)
-net_test = model(x, is_train=False, reuse=True)
-
-# cost for training
-y = net_train.outputs
-cost = tl.cost.cross_entropy(y, y_, name='xentropy')
+def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
+    with tf.GradientTape() as tape:
+        y_pred = network(X_batch)
+        _loss = cost(y_pred, y_batch)
+    grad = tape.gradient(_loss, network.trainable_weights)
+    train_op.apply_gradients(zip(grad, network.trainable_weights))
+    if acc is not None:
+        _acc = acc(y_pred, y_batch)
+        return _loss, _acc
+    else:
+        return _loss, None
 
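Note (not part of the diff): as a possible further step, the eager train step above can usually be compiled with tf.function for speed. This is a hedged sketch, assuming the network traces cleanly; it drops the NumPy-based accuracy callback, since NumPy code cannot run on symbolic tensors inside a traced function.

    @tf.function
    def _train_step_compiled(network, X_batch, y_batch, cost, train_op):
        # same logic as _train_step, traced into a graph by tf.function
        with tf.GradientTape() as tape:
            y_pred = network(X_batch)
            _loss = cost(y_pred, y_batch)
        grad = tape.gradient(_loss, network.trainable_weights)
        train_op.apply_gradients(zip(grad, network.trainable_weights))
        return _loss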
-# cost and accuracy for evalution
-y2 = net_test.outputs
-cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
-correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
-acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
-# define the optimizer
-train_params = tl.layers.get_variables_with_name('binarynet', True, True)
-train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)
+def accuracy(_logits, y_batch):
+    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
 
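Note (not part of the diff): a hedged single-step example of how _train_step and accuracy fit together, assuming net has already been built with model([None, 28, 28, 1]) as done further below.

    X_batch, y_batch = X_train[:batch_size], y_train[:batch_size]   # one minibatch
    loss, batch_acc = _train_step(net, X_batch, y_batch,
                                  cost=tl.cost.cross_entropy,
                                  train_op=tf.optimizers.Adam(learning_rate=0.0001),
                                  acc=accuracy)
    print(float(loss), batch_acc)                                    # scalar loss and batch accuracy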
-# initialize all variables in the session
-sess.run(tf.global_variables_initializer())
-
-net_train.print_params()
-net_train.print_layers()
 
 n_epoch = 200
 print_freq = 5
 
-# print(sess.run(net_test.all_params))  # print real values of parameters
+net = model([None, 28, 28, 1])
+train_op = tf.optimizers.Adam(learning_rate=0.0001)
+cost = tl.cost.cross_entropy
 
 for epoch in range(n_epoch):
     start_time = time.time()
+    train_loss, train_acc, n_batch = 0, 0, 0
+    net.train()
+
     for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
-        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})
+        _loss, acc = _train_step(net, X_train_a, y_train_a, cost=cost, train_op=train_op, acc=accuracy)
+        train_loss += _loss
+        train_acc += acc
+        n_batch += 1
+
+        # print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
+        # print("   train loss: %f" % (train_loss / n_batch))
+        # print("   train acc: %f" % (train_acc / n_batch))
 
     if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
         print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
-        train_loss, train_acc, n_batch = 0, 0, 0
-        for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
-            err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a})
-            train_loss += err
-            train_acc += ac
-            n_batch += 1
         print("   train loss: %f" % (train_loss / n_batch))
         print("   train acc: %f" % (train_acc / n_batch))
-        val_loss, val_acc, n_batch = 0, 0, 0
+        val_loss, val_acc, val_batch = 0, 0, 0
+        net.eval()
         for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True):
-            err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a})
-            val_loss += err
-            val_acc += ac
-            n_batch += 1
-        print("   val loss: %f" % (val_loss / n_batch))
-        print("   val acc: %f" % (val_acc / n_batch))
-
-print('Evaluation')
-test_loss, test_acc, n_batch = 0, 0, 0
+            _logits = net(X_val_a)
+            val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss')
+            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a))
+            val_batch += 1
+        print("   val loss: {}".format(val_loss / val_batch))
+        print("   val acc:  {}".format(val_acc / val_batch))
+
+net.test()
+test_loss, test_acc, n_test_batch = 0, 0, 0
 for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True):
-    err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a})
-    test_loss += err
-    test_acc += ac
-    n_batch += 1
-print("   test loss: %f" % (test_loss / n_batch))
-print("   test acc: %f" % (test_acc / n_batch))
+    _logits = net(X_test_a)
+    test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss')
+    test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a))
+    n_test_batch += 1
+print("   test loss: %f" % (test_loss / n_test_batch))
+print("   test acc: %f" % (test_acc / n_test_batch))
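Note (not part of the diff): after training, the learned weights could be saved and restored with the Model weight helpers; a hedged sketch, with the file name chosen only for illustration.

    net.save_weights('./model_binarynet.h5')   # persist trained weights
    # ... later, rebuild the network and restore them before evaluation
    net2 = model([None, 28, 28, 1])
    net2.load_weights('./model_binarynet.h5')
    net2.eval()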