"""
Predicting the next character with an RNN using TensorFlow.

Suppose we decide to predict the 11th character using the last 10 characters.

Say we have a short story excerpt:
"before the midnight..."

The tensor fed into TensorFlow's RNN should be shaped like this [some number of rows, 10]:
[
[b, e, f, o, r, e,  , t, h, e],
[e, f, o, r, e,  , t, h, e,  ],
[f, o, r, e,  , t, h, e,  , m],
[o, r, e,  , t, h, e,  , m, i],
[r, e,  , t, h, e,  , m, i, d], ... ]

The output should be like this (the character we want to predict):
[
[ ],
[m],
[i],
[d],
[n], ... ]

Roughly speaking, the RNN performs this process: output(t) = input_column(t) * w1 + hidden_state(t-1) * w2.
So, before returning the final output (i.e. the 11th character), the net repeats this process 10 times, once per column.

Of course, the characters must first be encoded into some kind of numerical representation.
After this data is fed into the RNN, TensorFlow does all the work. The rest is self-explanatory (well, kind of).

Helpful source: https://medium.com/towards-data-science/lstm-by-example-using-tensorflow-feb0c1968537
"""

import numpy as np
import tensorflow as tf
import matplotlib.pylab as pl
from urllib.request import urlopen


def download_data(
        url='https://ia601603.us.archive.org/3/items/CamusAlbertTheStranger/CamusAlbert-TheStranger_djvu.txt'):
    # Short story e.g.: http://textfiles.com/stories/aircon.txt
    # Longer story e.g.: https://ia601603.us.archive.org/3/items/CamusAlbertTheStranger/CamusAlbert-TheStranger_djvu.txt
    text = str(urlopen(url).read())  # bytes converted to a literal string, so escape sequences appear as text
    text = text.replace('\\r', '') \
        .replace('\\n', '') \
        .replace('\\\'', '') \
        .replace('\\xe2\\x99\\xa6', '') \
        .replace('\\xe2\\x80\\x94', '')
    return text
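# Why the double backslashes above: str() of a bytes object keeps escape
# sequences as literal backslash characters, so they must be stripped as text.
# A small sketch:
#
#     raw = str(b'Hello\r\nWorld')       # "b'Hello\\r\\nWorld'"
#     clean = raw.replace('\\r', '').replace('\\n', '')
#     # clean == "b'HelloWorld'"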


def get_dicts(text):
    """
    Returns a tuple of three objects:
    dictionary maps every unique character in the given text (keys) to its id (values)
    reverse_dictionary maps character ids (keys) back to characters (values)
    chars is the text converted into a list, where each element is a single character.
    """
    chars = list(text)  # splits the string into chars and puts them into a list
    # chars = ''.join(char for char in text).split()  # splits the string into words and stores them in a list

    dictionary, reverse_dictionary = {}, {}
    for id, char in enumerate(set(chars)):
        dictionary[char] = id
        reverse_dictionary[id] = char

    return dictionary, reverse_dictionary, chars
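# For example (hypothetical ids; the actual numbering depends on set() ordering):
#
#     d, rd, chars = get_dicts('abba')
#     # chars == ['a', 'b', 'b', 'a']
#     # d might be {'a': 0, 'b': 1} and rd {0: 'a', 1: 'b'}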


def get_data(chars, dictionary, time_steps):
    """
    Returns data ready to be fed into the neural net:
    x_data contains all sequences of characters (not the chars themselves, but their ids!). A single row corresponds to a single sequence.
    y_data contains the one-hot encoded id of the next character in each sequence.
    """
    x_data = np.zeros(shape=(len(chars) - time_steps, time_steps))
    y_data = np.zeros(shape=(len(chars) - time_steps, len(set(chars))))

    # slide a window of time_steps characters over the text: each row of x_data
    # holds the ids of one window, y_data one-hot encodes the following character
    for row in range(len(chars) - time_steps):
        for col in range(time_steps):
            x_data[row, col] = dictionary[chars[row + col]]
        y_data[row, dictionary[chars[row + time_steps]]] = 1

    return x_data, y_data
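# Shape sanity check for the windowing above (a sketch):
#
#     d, rd, chars = get_dicts('before the midnight')
#     x_ex, y_ex = get_data(chars, d, time_steps=10)
#     # x_ex.shape == (9, 10): nine sliding windows of ten character ids
#     # y_ex.shape == (9, len(d)): a one-hot row for the character after each window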


def forward_prop(x, w, n_hidden, drop):
    """
    RNN with tanh activation in the hidden layers and softmax activation in the last layer.
    Each element of n_hidden corresponds to a layer, and each number to the number of neurons in that layer.
    tf.contrib.rnn.static_rnn creates weights and biases automatically, so there is no need to initiate them manually;
    to follow things up, you can check all the tf variables with tf.get_collection('variables').
    """
    # split the data into time_steps columns, to recur over one column after another
    x_split = tf.split(x, time_steps, 1)

    # one LSTM cell per entry of n_hidden
    stacked_lstm_cells = [tf.contrib.rnn.BasicLSTMCell(n) for n in n_hidden]

    # create the net and add dropout
    lstm_cell = tf.contrib.rnn.MultiRNNCell(stacked_lstm_cells)
    lstm_cell_with_dropout = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=drop[0])

    # forward propagate
    outputs, state = tf.contrib.rnn.static_rnn(lstm_cell_with_dropout, x_split, dtype=tf.float32)

    # project the last output onto the vocabulary and return logits plus softmax probabilities
    logits = tf.matmul(outputs[-1], w)
    return logits, tf.nn.softmax(logits)


def get_mini_batch(x, y, batch_size):
    """
    Returns a random mini batch of batch_size rows from x and y.
    """
    ids = np.random.choice(len(x), size=batch_size, replace=False)
    return x[ids], y[ids]
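# Conceptually, static_rnn unrolls the cell over the split columns. A sketch of
# the idea (not the actual tf.contrib internals):
#
#     # state = initial_state
#     # for column in x_split:          # time_steps iterations
#     #     output, state = cell(column, state)
#     # the logits are then computed from the last output only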


def generate_new_text(txt, print_length, new_line, dictionary, reverse_dictionary):
    """
    Generates text by predicting next character.
    Function arguments:
    txt - seed text to start from (must be exactly time_steps characters long)
    print_length - number of characters to generate
    new_line - break the printed line after this many characters
    dictionary, reverse_dictionary - the char/id mappings returned by get_dicts
    """
    # encode the seed text into a single row of character ids
    x_data_sample = np.zeros(shape=(1, time_steps))
    for id, char in enumerate(list(txt)):
        x_data_sample[:, id] = dictionary[char]

    # print the text given as argument
    print(txt, end='')

    # predict next char, print, use predicted char to predict next and so on
    txt_length = 1
    for _ in range(print_length):
        # sample the next char id from the softmax distribution instead of taking a greedy argmax
        probs = sess.run([y_], feed_dict={x: x_data_sample, drop: [1.0]})[0].ravel()
        next_char_id = np.random.choice(len(dictionary), p=probs)
        next_char = reverse_dictionary[next_char_id]
        x_data_sample = np.delete(x_data_sample, 0, axis=1)
        x_data_sample = np.insert(x_data_sample, len(x_data_sample[0]), next_char_id, axis=1)

        # print the predicted character, breaking the line every new_line characters
        print(next_char, end='')
        txt_length += 1
        if txt_length % new_line == 0:
            print()


# net parameters
n_hidden = [512, 512]  # neurons per hidden layer (assumed values)
batch_size = 250
time_steps = 40  # size of the sequence of chars
learning_rate = 1e-3
dropout = 0.9


# download and prepare data, initiate weights
text = download_data()
dictionary, reverse_dictionary, chars = get_dicts(text)
x_data, y_data = get_data(chars, dictionary, time_steps)

# initiate tf placeholders
x = tf.placeholder(tf.float32, [None, time_steps])
y = tf.placeholder(tf.float32, [None, len(dictionary)])
drop = tf.placeholder(tf.float32, [1])  # dropout keep probability: fed 0.9 while training, 1.0 while generating

# create other tf objects
w = tf.Variable(tf.random_normal([n_hidden[-1], len(dictionary)]), dtype=tf.float32)  # last layer weights
logits, y_ = forward_prop(x, w, n_hidden, drop)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y_, axis=1), tf.argmax(y, axis=1)), tf.float32))
saver = tf.train.Saver()

# initiate tf session
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# initiate new training session
cost_hist = []
iter = 0

# training loop
while True:
    iter += 1

    # get mini batch and train
    x_batch, y_batch = get_mini_batch(x_data, y_data, batch_size)
    _ = sess.run([optimizer], feed_dict={x: x_batch, y: y_batch, drop: [dropout]})

    # plot and print
    if iter % 100 == 0:
        cost_, acc = sess.run([cost, accuracy], feed_dict={x: x_batch, y: y_batch, drop: [1.0]})
        cost_hist.append(cost_)
        pl.cla()
        pl.plot(cost_hist)
        pl.pause(1e-99)

    # generate new text by giving the rnn something to start with
    if iter % 500 == 0 or iter == 1:
        starting_txt = 'Ive had the body moved to our little mor'
        generate_new_text(txt=starting_txt, print_length=100, new_line=100, dictionary=dictionary,
                          reverse_dictionary=reverse_dictionary)

    # save
    # if iter % 500 == 0:
    #     save_path = saver.save(sess, "model.ckpt")