-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
51 lines (38 loc) · 1.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import sys, os
import pickle
import tensorflow as tf
from collections import namedtuple
from trainer.Config import Config
from trainer.Config import Value, FCHidden
from networkbuilder.FFNetworkBuilder import FFNetworkBuilder
from hiddenbuilder.fc.FFHiddenBuilder import FFHiddenBuilder
from trainer.Trainer import Trainer, TrainerConfig
from dataset.MNIST import MNIST
# from tensorflow.examples.tutorials.mnist import input_data
# mnist = input_data.read_data_sets("/tmp/data")
# Training hyper-parameters for the MNIST smoke test run by main().
BATCH_SIZE = 10        # passed to MNIST(...)
LEARNING_RATE = 1e-4   # passed to Config(...) as its fourth argument
EPOCHS = 100           # passed to TrainerConfig(epochs=...)
# NOTE(review): if keep_prob is a dropout keep probability, 0.05 keeps only
# 5% of activations, which is unusually aggressive — confirm intended value.
KEEP_PROB = 0.05       # passed to TrainerConfig(keep_prob=...)
DISPLAY_STEP = 500     # passed to TrainerConfig(display_after=...)
def main():
    """Test the feed-forward framework on the MNIST dataset.

    Builds a network from a ``Config`` with 784-feature float32 inputs,
    int64 targets with ``cls=10``, and fully connected hidden layers of
    300 and 150 units. It then trains that network with ``Trainer`` on
    the ``MNIST`` dataset wrapper.

    Hyper-parameters come from the module-level constants BATCH_SIZE,
    LEARNING_RATE, EPOCHS, KEEP_PROB and DISPLAY_STEP. Checkpoint and
    summary paths are passed as ``None``.
    """
    dataset = MNIST(BATCH_SIZE)

    # 784 features per example (presumably flattened 28x28 MNIST images);
    # the leading None is presumably the batch dimension — TODO confirm.
    inputs = Value(type=tf.float32, shape=(None, 784), cls=None)

    # cls=10 is presumably the number of target classes (digits 0-9).
    # shape=None leaves the label shape unconstrained. Beware: the spelling
    # ``(None)`` is the same value, NOT a 1-tuple. If a rank-1 label vector
    # is required, use ``(None,)`` instead — TODO confirm with FFNetworkBuilder.
    targets = Value(type=tf.int64, shape=None, cls=10)

    fc_hidden = FCHidden(weights=[300, 150])
    config = Config(inputs, targets, fc_hidden, LEARNING_RATE)

    network_builder = FFNetworkBuilder(config)
    hidden_builder = FFHiddenBuilder(network_builder)
    # The return value is not used here; presumably build_network is called
    # for its effect on network_builder, which is handed to Trainer below.
    network_builder.build_network(hidden_builder)

    train_config = TrainerConfig(
        epochs=EPOCHS,
        display_after=DISPLAY_STEP,
        keep_prob=KEEP_PROB,
        checkpoint_path=None,
        summary_path=None,
    )
    trainer = Trainer(network_builder, train_config)
    trainer.train(dataset)
if __name__ == '__main__':
    # Only start training when run as a script, not when imported.
    main()