-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBase.py
123 lines (107 loc) · 4.88 KB
/
Base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import print_function
import datetime
import sys
import time
import numpy as np
import tensorflow as tf
from tqdm import tqdm
class Base(object):
    """Shared train/eval/save scaffolding for a DMN-style multi-answer VQA model.

    Subclasses implement :meth:`build`, which must create the attributes this
    class reads: the placeholders ``training``, ``input``, ``question``,
    ``type`` and the per-answer-type placeholders ``answer_b/n/m/c``; the ops
    ``gradient_descent_b/n/m/c``, ``predicts_t``, ``predicts_b/n/m/c`` and
    ``accuracy_b/n/m/c``.  The four answer types, in fixed order, are:
    'b' (yes/no), 'n' (number), 'm' (other word), 'c' (color).
    """

    def __init__(self, params):
        """Build the graph, the saver, and the merged summary ops.

        Args:
            params: config object; must expose ``save_dir``, ``num_epochs``
                and ``num_steps``.
        """
        self.params = params
        self.save_dir = params.save_dir
        self.starttime_init = time.time()
        with tf.variable_scope('DMN'):
            print('Building DMN...')
            self.global_step = tf.Variable(0, name='global_step', trainable=False)
            self.build()  # subclass hook: constructs the actual network
            self.saver = tf.train.Saver(tf.trainable_variables())
            for var in tf.trainable_variables():
                # NOTE(review): var.name contains ':0'; TF1 sanitizes the
                # illegal ':' in summary names with a warning — confirm OK.
                tf.summary.histogram(var.name, var)
            self.merged_summary_op = tf.summary.merge_all()
            # Per-answer-type summaries live in their own collections so each
            # type's summary is only fetched when its sub-batch is non-empty
            # (see train_batch).
            self.merged_summary_op_b_loss = tf.summary.merge_all(key='b_stuff')
            self.merged_summary_op_n_loss = tf.summary.merge_all(key='n_stuff')
            self.merged_summary_op_m_loss = tf.summary.merge_all(key='m_stuff')
            self.merged_summary_op_c_loss = tf.summary.merge_all(key='c_stuff')

    def build(self):
        """Construct the model graph; must be overridden by subclasses."""
        raise NotImplementedError()

    def get_feed_dict(self, batch, type, training):
        """Build a feed dict for one answer type out of a 20-tuple batch.

        ``batch`` is laid out as four consecutive ``(Anns, Is, Xs, Qs, As)``
        groups, one per answer type in the order b, n, m, c (see :meth:`eval`,
        which unpacks the same layout).

        Args:
            batch: the 20-element batch tuple described above.
            type: 'b', 'n' or 'm'; any other value selects 'c'.
            training: bool fed to the ``self.training`` placeholder.

        Returns:
            Feed dict mapping the shared placeholders plus the answer
            placeholder of the selected type.
        """
        if type == 'b':
            (_, Is, Xs, Qs, As, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _) = batch
            type_ids = np.zeros_like(As)
            answer = self.answer_b
        elif type == 'n':
            (_, _, _, _, _, _, Is, Xs, Qs, As, _, _, _, _, _, _, _, _, _, _) = batch
            type_ids = np.ones_like(As)
            answer = self.answer_n
        elif type == 'm':
            (_, _, _, _, _, _, _, _, _, _, _, Is, Xs, Qs, As, _, _, _, _, _) = batch
            type_ids = np.repeat(2, len(As))
            answer = self.answer_m
        else:
            (_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, Is, Xs, Qs, As) = batch
            type_ids = np.repeat(3, len(As))
            answer = self.answer_c
        return {self.training: training, self.input: Xs, self.question: Qs,
                self.type: type_ids, answer: As}

    def train_batch(self, sess, batch, sum_writer):
        """Run one optimizer step per answer type on `batch`, logging summaries.

        Empty sub-batches (no examples of that answer type) are skipped.
        """
        for (type, gradient_descent, summary_op_for_that_type) in [
                ('b', self.gradient_descent_b, self.merged_summary_op_b_loss),
                ('n', self.gradient_descent_n, self.merged_summary_op_n_loss),
                ('m', self.gradient_descent_m, self.merged_summary_op_m_loss),
                ('c', self.gradient_descent_c, self.merged_summary_op_c_loss)]:
            feed_dict = self.get_feed_dict(batch, type, True)
            if len(feed_dict[self.type]) > 0:
                _, global_step, summary_all, specialized_summary = sess.run(
                    [gradient_descent, self.global_step, self.merged_summary_op,
                     summary_op_for_that_type],
                    feed_dict=feed_dict)
                sum_writer.add_summary(summary_all, global_step=global_step)
                sum_writer.add_summary(specialized_summary, global_step=global_step)

    def test_batch(self, sess, batch):
        """Evaluate `batch` per answer type.

        Returns:
            A flat list of 12 items: for each type in b, n, m, c order,
            ``[type_predictions, answer_predictions, accuracy]``.  An empty
            sub-batch yields ``[[], [], -1.0]`` (accuracy -1.0 marks "no data").
        """
        ret_list = []
        for (type, predicts, accuracy) in \
                [('b', self.predicts_b, self.accuracy_b), ('n', self.predicts_n, self.accuracy_n),
                 ('m', self.predicts_m, self.accuracy_m), ('c', self.predicts_c, self.accuracy_c)]:
            feed_dict = self.get_feed_dict(batch, type, False)
            if len(feed_dict[self.type]) > 0:
                ret_list += sess.run([self.predicts_t, predicts, accuracy], feed_dict=feed_dict)
            else:
                # accuracy -1.0 means empty test data
                ret_list += [[], [], -1.0]
        return ret_list

    def train(self, sess, train_data, val_data, sum_writer):
        """Run the full training loop, evaluating on `val_data` every epoch."""
        for epoch in tqdm(range(self.params.num_epochs), desc='Epoch', maxinterval=10000, ncols=100):
            for step in tqdm(range(self.params.num_steps), desc='Step', maxinterval=10000, ncols=100):
                batch = train_data.next_batch()
                self.train_batch(sess, batch, sum_writer)
            self.eval(sess, val_data)
        print('Training complete.')

    def eval(self, sess, eval_data):
        """Evaluate one batch of `eval_data` and print per-type predictions/accuracy."""
        batch = eval_data.next_batch()
        predicts_tb, predicts_b, accuracy_b, predicts_tn, predicts_n, accuracy_n, \
            predicts_tm, predicts_m, accuracy_m, predicts_tc, predicts_c, accuracy_c = self.test_batch(sess, batch)
        (Anns_b, Is_b, _, _, _, Anns_n, Is_n, _, _, _, Anns_m, Is_m, _, _, _, Anns_c, Is_c, _, _, _) = batch
        for predict, Ann, I in zip(predicts_b, Anns_b, Is_b):
            # eval_data.visualize(Ann, I)
            tqdm.write('Predicted answer: %s' % ('yes' if predict == 1 else 'no'))
        tqdm.write('Accuracy (yes/no): %f' % accuracy_b)
        # BUG FIX: originally zipped predicts_b with the number-type
        # annotations, printing yes/no predictions against number questions.
        for predict, Ann, I in zip(predicts_n, Anns_n, Is_n):
            # eval_data.visualize(Ann, I)
            tqdm.write('Predicted answer: %d' % (predict))
        tqdm.write('Accuracy (number): %f' % accuracy_n)
        for predict, Ann, I in zip(predicts_c, Anns_c, Is_c):
            # eval_data.visualize(Ann, I)
            tqdm.write('Predicted answer: %s' % eval_data.index_to_color(predict))
        tqdm.write('Accuracy (color): %f' % accuracy_c)
        for predict, Ann, I in zip(predicts_m, Anns_m, Is_m):
            # eval_data.visualize(Ann, I)
            tqdm.write('Predicted answer: %s' % eval_data.index_to_word(predict))
        tqdm.write('Accuracy (other): %f' % accuracy_m)

    def save(self, sess):
        """Checkpoint the model to ``save_dir``, tagged with the global step."""
        print('Saving model to %s' % self.save_dir)
        self.saver.save(sess, self.save_dir, self.global_step)

    def load(self, sess):
        """Restore the latest checkpoint from ``save_dir``; exit(1) if none exists."""
        print('Loading model ...')
        checkpoint = tf.train.get_checkpoint_state(self.save_dir)
        if checkpoint is None:
            print('Error: No saved model found')
            # BUG FIX: was sys.exit(0), which reports success on a fatal error.
            sys.exit(1)
        self.saver.restore(sess, checkpoint.model_checkpoint_path)