# pointer_net.py
import tensorflow as tf

# Special token ids
START_ID = 0
PAD_ID = 1
END_ID = 2


class PointerWrapper(tf.contrib.seq2seq.AttentionWrapper):
  """Customized AttentionWrapper for PointerNet."""

  def __init__(self, cell, attention_size, memory, initial_cell_state=None, name=None):
    # The paper uses the Bahdanau attention mechanism.
    # We want the raw attention scores rather than the normalized alignment
    # probabilities, so probability_fn is customized to return its input unchanged.
    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attention_size, memory, probability_fn=lambda x: x)
    # Following the paper, the attention is not concatenated with the cell input,
    # so cell_input_fn passes the input through unchanged.
    cell_input_fn = lambda inputs, attention: inputs
    # Call super __init__
    super(PointerWrapper, self).__init__(cell,
                                         attention_mechanism=attention_mechanism,
                                         attention_layer_size=None,
                                         alignment_history=False,
                                         cell_input_fn=cell_input_fn,
                                         output_attention=True,
                                         initial_cell_state=initial_cell_state,
                                         name=name)

  @property
  def output_size(self):
    return self.state_size.alignments

  def call(self, inputs, state):
    # Run the wrapped AttentionWrapper and emit the alignments (pointer scores)
    # over the encoder positions as the cell output.
    _, next_state = super(PointerWrapper, self).call(inputs, state)
    return next_state.alignments, next_state
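
# At each decoder step, PointerWrapper emits the un-normalized attention scores
# over the memory (encoder) positions as its output, i.e. a tensor of shape
# [batch_size, max_input_sequence_len+1]. A softmax over this output is the
# "pointer" distribution over the input elements; PointerNet below uses it
# directly as the logits of its sparse softmax cross-entropy loss.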


class PointerNet(object):
  """Pointer Net model.

  This class implements a multi-layer Pointer Network aimed at solving the
  convex hull problem. It is almost the same as the model described in
  https://arxiv.org/abs/1506.03134.
  """

  def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence_len=7,
               rnn_size=128, attention_size=128, num_layers=2, beam_width=2,
               learning_rate=0.001, max_gradient_norm=5, forward_only=False):
    """Create the model.

    Args:
      batch_size: the batch size used during training
      max_input_sequence_len: the maximum input sequence length
      max_output_sequence_len: the maximum output sequence length
      rnn_size: the number of hidden units in each RNN cell
      attention_size: the dimensionality of the attention mechanism
      num_layers: the number of stacked RNN layers
      beam_width: the beam width used during beam search
      learning_rate: the initial learning rate during training
      max_gradient_norm: gradients will be clipped to at most this norm
      forward_only: whether the model runs inference only (no training ops)
    """
    self.batch_size = batch_size
    self.max_input_sequence_len = max_input_sequence_len
    self.max_output_sequence_len = max_output_sequence_len
    self.forward_only = forward_only
    self.init_learning_rate = learning_rate
    # There are three special tokens, 'START', 'PAD' and 'END',
    # so the vocabulary size is the number of input positions plus 3.
    self.vocab_size = max_input_sequence_len + 3
    # Global step
    self.global_step = tf.Variable(0, trainable=False)
    # Choose the LSTM cell
    cell = tf.contrib.rnn.LSTMCell
    # Create placeholders
    self.inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.max_input_sequence_len, 2], name="inputs")
    self.outputs = tf.placeholder(tf.int32, shape=[self.batch_size, self.max_output_sequence_len+1], name="outputs")
    self.enc_input_weights = tf.placeholder(tf.int32, shape=[self.batch_size, self.max_input_sequence_len], name="enc_input_weights")
    self.dec_input_weights = tf.placeholder(tf.int32, shape=[self.batch_size, self.max_output_sequence_len], name="dec_input_weights")
    # Compute the actual sequence lengths from the 0/1 weights
    enc_input_lens = tf.reduce_sum(self.enc_input_weights, axis=1)
    dec_input_lens = tf.reduce_sum(self.dec_input_weights, axis=1)
    # Embedding for the special tokens
    special_token_embedding = tf.get_variable("special_token_embedding", [3, 2], tf.float32, tf.contrib.layers.xavier_initializer())
    # Embedding table: special token embeddings followed by the input points
    # Shape: [batch_size, vocab_size, features_size]
    embedding_table = tf.concat([tf.tile(tf.expand_dims(special_token_embedding, 0), [self.batch_size, 1, 1]), self.inputs], axis=1)
    # Unstack embedding_table
    # Shape: batch_size*[vocab_size, features_size]
    embedding_table_list = tf.unstack(embedding_table, axis=0)
    # Unstack outputs
    # Shape: (max_output_sequence_len+1)*[batch_size]
    outputs_list = tf.unstack(self.outputs, axis=1)
    # Targets (the decoder inputs shifted by one position)
    # Shape: [batch_size, max_output_sequence_len]
    self.targets = tf.stack(outputs_list[1:], axis=1)
    # Decoder input ids
    # Shape: batch_size*[max_output_sequence_len, 1]
    dec_input_ids = tf.unstack(tf.expand_dims(tf.stack(outputs_list[:-1], axis=1), 2), axis=0)
    # Encoder input ids (token ids 2..vocab_size-1, i.e. END plus the input points)
    # Shape: batch_size*[max_input_sequence_len+1, 1]
    enc_input_ids = [tf.expand_dims(tf.range(2, self.vocab_size), 1)] * self.batch_size
    # Look up encoder and decoder inputs
    encoder_inputs = []
    decoder_inputs = []
    for i in range(self.batch_size):
      encoder_inputs.append(tf.gather_nd(embedding_table_list[i], enc_input_ids[i]))
      decoder_inputs.append(tf.gather_nd(embedding_table_list[i], dec_input_ids[i]))
    # Shape: [batch_size, max_input_sequence_len+1, 2]
    encoder_inputs = tf.stack(encoder_inputs, axis=0)
    # Shape: [batch_size, max_output_sequence_len, 2]
    decoder_inputs = tf.stack(decoder_inputs, axis=0)
    # Stack encoder cells if needed
    if num_layers > 1:
      fw_enc_cell = tf.contrib.rnn.MultiRNNCell([cell(rnn_size) for _ in range(num_layers)])
      bw_enc_cell = tf.contrib.rnn.MultiRNNCell([cell(rnn_size) for _ in range(num_layers)])
    else:
      fw_enc_cell = cell(rnn_size)
      bw_enc_cell = cell(rnn_size)
    # Tile encoder_inputs and enc_input_lens if forward only,
    # so every beam sees the same memory
    if self.forward_only:
      encoder_inputs = tf.contrib.seq2seq.tile_batch(encoder_inputs, beam_width)
      enc_input_lens = tf.contrib.seq2seq.tile_batch(enc_input_lens, beam_width)
    # Encode the input to obtain the memory for later attention queries
    memory, _ = tf.nn.bidirectional_dynamic_rnn(fw_enc_cell, bw_enc_cell, encoder_inputs, enc_input_lens, dtype=tf.float32)
    # Shape: [batch_size(*beam_width), max_input_sequence_len+1, 2*rnn_size]
    memory = tf.concat(memory, 2)
    # Pointer cell: its per-step output is the attention scores over the memory
    pointer_cell = PointerWrapper(cell(rnn_size), attention_size, memory)
    # Stack decoder cells if needed, keeping the pointer cell on top
    if num_layers > 1:
      dec_cell = tf.contrib.rnn.MultiRNNCell([cell(rnn_size) for _ in range(num_layers-1)] + [pointer_cell])
    else:
      dec_cell = pointer_cell
    # Two decoding scenarios: beam search for inference, teacher forcing for training
    if self.forward_only:
      # Tile embedding_table across the beam dimension
      # Shape: [batch_size, beam_width, vocab_size, features_size]
      tile_embedding_table = tf.tile(tf.expand_dims(embedding_table, 1), [1, beam_width, 1, 1])
      # Customized embedding_lookup_fn
      def embedding_lookup(ids):
        # The decoder outputs ids in the range 0 to max_input_sequence_len,
        # while embedding_table also contains the two extra START and PAD tokens
        # at the front, so the ids are shifted by 2 before the lookup.
        # Shape: [batch_size, beam_width]
        ids = ids + 2
        # Shape: [batch_size, beam_width, vocab_size]
        one_hot_ids = tf.cast(tf.one_hot(ids, self.vocab_size), dtype=tf.float32)
        # Shape: [batch_size, beam_width, vocab_size, 1]
        one_hot_ids = tf.expand_dims(one_hot_ids, -1)
        # Shape: [batch_size, beam_width, features_size]
        next_inputs = tf.reduce_sum(one_hot_ids * tile_embedding_table, axis=2)
        return next_inputs
      # Shift the special ids accordingly so that 'BeamSearchDecoder' can be used
      shifted_START_ID = START_ID - 2
      shifted_END_ID = END_ID - 2
      # Beam search decoder
      decoder = tf.contrib.seq2seq.BeamSearchDecoder(dec_cell, embedding_lookup,
                                                     tf.tile([shifted_START_ID], [self.batch_size]), shifted_END_ID,
                                                     dec_cell.zero_state(self.batch_size*beam_width, tf.float32), beam_width)
      # Decode
      outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
      # predicted_ids
      # Shape: [batch_size, max_output_sequence_len, beam_width]
      predicted_ids = outputs.predicted_ids
      # Transpose predicted_ids
      # Shape: [batch_size, beam_width, max_output_sequence_len]
      self.predicted_ids = tf.transpose(predicted_ids, [0, 2, 1])
    else:
      # Maximum decoder sequence length in the current batch
      cur_batch_max_len = tf.reduce_max(dec_input_lens)
      # Training helper (teacher forcing)
      helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs, dec_input_lens)
      # Basic decoder
      decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, helper, dec_cell.zero_state(self.batch_size, tf.float32))
      # Decode
      outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)
      # Logits: un-normalized pointer scores over the max_input_sequence_len+1 positions
      logits = outputs.rnn_output
      # Greedy predictions together with their scores, for monitoring
      self.predicted_ids_with_logits = tf.nn.top_k(logits)
      # Pad logits along the time axis to the same length as the targets
      logits = tf.concat([logits, tf.ones([self.batch_size, self.max_output_sequence_len-cur_batch_max_len, self.max_input_sequence_len+1])], axis=1)
      # Shift the target ids down by 2 because the predictions range over
      # 0..max_input_sequence_len while the target token ids range over
      # 2..max_input_sequence_len+2; padded positions are zeroed out by the weights.
      self.shifted_targets = (self.targets - 2) * self.dec_input_weights
      # Per-step cross-entropy losses
      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.shifted_targets, logits=logits)
      # Total loss, masked by the decoder input weights and averaged over the batch
      self.loss = tf.reduce_sum(losses * tf.cast(self.dec_input_weights, tf.float32)) / self.batch_size
      # Get all trainable variables
      parameters = tf.trainable_variables()
      # Calculate gradients
      gradients = tf.gradients(self.loss, parameters)
      # Clip gradients
      clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)
      # Optimization
      # optimizer = tf.train.GradientDescentOptimizer(self.init_learning_rate)
      optimizer = tf.train.AdamOptimizer(self.init_learning_rate)
      # Update operator
      self.update = optimizer.apply_gradients(zip(clipped_gradients, parameters), global_step=self.global_step)
      # Summaries for the loss, parameters and gradients
      tf.summary.scalar('loss', self.loss)
      for p in parameters:
        tf.summary.histogram(p.op.name, p)
      for g in gradients:
        tf.summary.histogram(g.op.name, g)
      # Summary operator
      self.summary_op = tf.summary.merge_all()
      # DEBUG PART
      self.debug_var = logits
      # /DEBUG PART
    # Saver
    self.saver = tf.train.Saver(tf.global_variables())

  def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weights=None):
    """Run one step of the model, feeding the given inputs.

    Args:
      session: the TensorFlow session to use.
      inputs: the 2D point coordinates. Shape: [batch_size, max_input_sequence_len, 2]
      enc_input_weights: the weights of the encoder input points. Shape: [batch_size, max_input_sequence_len]
      outputs: the point indices into inputs. Shape: [batch_size, max_output_sequence_len+1]
      dec_input_weights: the weights of the decoder input points. Shape: [batch_size, max_output_sequence_len]

    Returns:
      (training)
        The summary
        The total loss
        The predicted ids with logits
        The targets
        The variable for debugging
      (evaluation)
        The predicted ids
    """
    # Fill up inputs
    input_feed = {}
    input_feed[self.inputs] = inputs
    input_feed[self.enc_input_weights] = enc_input_weights
    if not self.forward_only:
      input_feed[self.outputs] = outputs
      input_feed[self.dec_input_weights] = dec_input_weights
    # Fill up outputs
    if self.forward_only:
      output_feed = [self.predicted_ids]
    else:
      output_feed = [self.update, self.summary_op, self.loss,
                     self.predicted_ids_with_logits, self.shifted_targets, self.debug_var]
    # Run step
    outputs = session.run(output_feed, input_feed)
    # Return
    if self.forward_only:
      return outputs[0]
    else:
      return outputs[1], outputs[2], outputs[3], outputs[4], outputs[5]
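

# -----------------------------------------------------------------------------
# Minimal usage sketch, assuming TensorFlow 1.x (with tf.contrib) and NumPy.
# The data below is random placeholder data, not a real convex-hull dataset;
# it only exists to show the expected tensor shapes and the training-mode call.
if __name__ == "__main__":
  import numpy as np
  batch_size, max_in, max_out = 128, 5, 7
  model = PointerNet(batch_size=batch_size,
                     max_input_sequence_len=max_in,
                     max_output_sequence_len=max_out,
                     forward_only=False)
  # Random 2D points; all encoder positions are marked as valid.
  points = np.random.rand(batch_size, max_in, 2).astype(np.float32)
  enc_weights = np.ones((batch_size, max_in), dtype=np.int32)
  # Dummy decoder sequence: START, three point tokens (point ids start at 3), END, PAD...
  dec_outputs = np.full((batch_size, max_out + 1), PAD_ID, dtype=np.int32)
  dec_outputs[:, 0] = START_ID
  dec_outputs[:, 1:4] = [3, 4, 5]
  dec_outputs[:, 4] = END_ID
  dec_weights = np.zeros((batch_size, max_out), dtype=np.int32)
  dec_weights[:, :4] = 1  # weight the three point steps plus the END step
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    summary, loss, predicted, targets, debug = model.step(
        sess, points, enc_weights, dec_outputs, dec_weights)
    print("training loss: %f" % loss)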