11"""
22Predicting next character with RNN using Tensorflow.
33Suppose we decide to predict the 11'th character using the last 10 characters.
4-
54Say we have a short story excerpt:
65"before the midnight..."
7-
86The tensor fed into tensorflow's rnn should be shaped like this [some nr of rows, 10]:
97[
108[b, e, f, o, r, e, , t, h, e],
119[e, f, o, r, e, , t, h, e, ],
1210[f, o, r, e, , t, h, e, , m],
1311[o, r, e, , t, h, e, , m, i],
1412[r, e, , t, h, e, , m, i, d], ... ]
15-
1613The output should be like this (the character we want to predict):
1714[
1815[ ],
1916[m],
2017[i],
2118[d],
2219[n], ... ]
23-
2420Roughly speaking, rnn does this process: output(t) = input_column(t) * w1 + hidden_state(t-1) * w2.
2521So, before returning the final output (i.e. the 11th character) the net will repeat this process 10 times.
26-
2722Of course, firstly, characters must be encoded into some kind of numerical representation.
2823After feeding this data into rnn, tensorflow does all the work. The rest is self-explanatory (well, kind of).
29-
3024Helpful source: https://medium.com/towards-data-science/lstm-by-example-using-tensorflow-feb0c1968537
3125"""
3226
3630from urllib .request import urlopen
3731
3832
39- def download_data (url = 'https://ia601603.us.archive.org/3/items/CamusAlbertTheStranger/CamusAlbert-TheStranger_djvu.txt' ):
33+ def download_data (
34+ url = 'https://ia601603.us.archive.org/3/items/CamusAlbertTheStranger/CamusAlbert-TheStranger_djvu.txt' ):
4035 # Short story e.g.: http://textfiles.com/stories/aircon.txt
4136 # Longer story e.g.: https://ia601603.us.archive.org/3/items/CamusAlbertTheStranger/CamusAlbert-TheStranger_djvu.txt
4237 text = str (urlopen (url ).read ())
43- text = text .replace ('\\ r' , '' )\
44- .replace ('\\ n' , '' )\
45- .replace ('\\ \' ' , '' )\
46- .replace ('\\ xe2\\ x99\\ xa6' , '' )\
47- .replace ('\\ xe2\\ x80\\ x94' , '' )
38+ text = text .replace ('\\ r' , '' ) \
39+ .replace ('\\ n' , '' ) \
40+ .replace ('\\ \' ' , '' ) \
41+ .replace ('\\ xe2\\ x99\\ xa6' , '' ) \
42+ .replace ('\\ xe2\\ x80\\ x94' , '' )
4843 return text
4944
5045
5146def get_dicts (text ):
52-
5347 """
5448 Returns a tuple of three objects:
5549 dictionary is a dictionary that contains all unique characters in given text (keys) and their ids (values)
5650 reverse_dictionary is a dictionary that contains character's ids (as keys) and characters (as values)
5751 chars is text converted into a list, where each element is single character.
5852 """
59-
53+
6054 chars = list (text ) # splits strings into chars and puts it into a list
6155 # chars = ''.join(char for char in text).split() # splits string into swords and stores it in a list
6256
6357 dictionary , reverse_dictionary = {}, {}
6458 for id , char in enumerate (set (chars )):
6559 dictionary [char ] = id
6660 reverse_dictionary [id ] = char
67-
61+
6862 return dictionary , reverse_dictionary , chars
6963
7064
7165def get_data (chars , dictionary , time_steps ):
72-
7366 """
7467 Returns data ready to be fed into neural net:
7568 x_data contains all sequences of characters (not chars, but their ids!). Single row corresponds to single sequence.
7669 y_data contains the id of next character in a sequence
7770 """
78-
71+
7972 x_data = np .zeros (shape = (len (chars ) - time_steps , time_steps ))
8073 y_data = np .zeros (shape = (len (chars ) - time_steps , len (set (chars ))))
8174
@@ -86,15 +79,14 @@ def get_data(chars, dictionary, time_steps):
8679 return x_data , y_data
8780
8881
89- def forward_prop (x , w , n_hidden ):
90-
82+ def forward_prop (x , w , n_hidden , drop ):
9183 """
9284 RNN with tanh activation in hidden layers and softmax activation in the last layer.
9385 Number of elements in n_hidden correspond to layers, each number corresponds to number of neurons in a layer.
9486 tf.contrib.rnn.static_rnn create weights and biases automatically, so there is no need to initiate it manually
9587 to follow things up, you can check all the tf variables by tf.get_collection('variables')
9688 """
97-
89+
9890 # split the data to time_steps columns, to recure one column by another
9991 x_split = tf .split (x , time_steps , 1 )
10092
@@ -105,7 +97,7 @@ def forward_prop(x, w, n_hidden):
10597
10698 # create the net and add dropout
10799 lstm_cell = tf .contrib .rnn .MultiRNNCell (stacked_lstm_cells )
108- lstm_cell_with_dropout = tf .contrib .rnn .DropoutWrapper (lstm_cell , output_keep_prob = 0.9 )
100+ lstm_cell_with_dropout = tf .contrib .rnn .DropoutWrapper (lstm_cell , output_keep_prob = drop [ 0 ] )
109101
110102 # forwawrd propagate
111103 outputs , state = tf .contrib .rnn .static_rnn (lstm_cell_with_dropout , x_split , dtype = tf .float32 )
@@ -123,7 +115,6 @@ def get_mini_batch(x, y, batch_size):
123115
124116
125117def generate_new_text (txt , print_length , new_line , dictionary , reverse_dictionary ):
126-
127118 """
128119 Generates text by predicting next character.
129120 Function arguments:
@@ -141,13 +132,13 @@ def generate_new_text(txt, print_length, new_line, dictionary, reverse_dictionar
141132 x_data_sample [:, id ] = dictionary [char ]
142133
143134 # print the text given as argument
144- print (txt )
135+ print (txt , end = '' )
145136
146137 # predict next char, print, use predicted char to predict next and so on
147138 txt_length = 1
148139 for _ in range (print_length ):
149- next_char_id = np .argmax ( sess .run ([y_ ], feed_dict = {x : x_data_sample ,})[0 ], axis = 1 )
150- next_char = reverse_dictionary [next_char_id [ 0 ] ]
140+ next_char_id = np .random . choice ( 74 , p = sess .run ([y_ ], feed_dict = {x : x_data_sample , drop : [ 1.0 ] })[0 ]. ravel () )
141+ next_char = reverse_dictionary [next_char_id ]
151142 x_data_sample = np .delete (x_data_sample , 0 , axis = 1 )
152143 x_data_sample = np .insert (x_data_sample , len (x_data_sample [0 ]), next_char_id , axis = 1 )
153144
@@ -164,6 +155,7 @@ def generate_new_text(txt, print_length, new_line, dictionary, reverse_dictionar
164155batch_size = 250
165156time_steps = 40 # size of sequence of chars
166157learning_rate = 1e-3
158+ dropout = 0.9
167159
168160# download and prepare data, initiate weights
169161text = download_data ()
@@ -173,41 +165,47 @@ def generate_new_text(txt, print_length, new_line, dictionary, reverse_dictionar
173165# initiate tf placeholders
174166x = tf .placeholder (tf .float32 , [None , time_steps ])
175167y = tf .placeholder (tf .float32 , [None , len (dictionary )])
168+ drop = tf .placeholder (tf .float32 , [1 ])
176169
177170# create other tf objects
178- w = tf .Variable (tf .random_normal ([n_hidden [- 1 ], len (dictionary )]), dtype = tf .float32 ) # last layer weights
179- logits , y_ = forward_prop (x , w , n_hidden )
171+ w = tf .Variable (tf .random_normal ([n_hidden [- 1 ], len (dictionary )]), dtype = tf .float32 ) # last layer weights
172+ logits , y_ = forward_prop (x , w , n_hidden , drop )
180173cost = tf .reduce_mean (tf .nn .softmax_cross_entropy_with_logits (logits = logits , labels = y ))
181174optimizer = tf .train .AdamOptimizer (learning_rate ).minimize (cost )
182175accuracy = tf .reduce_mean (tf .cast (tf .equal (tf .argmax (y_ , axis = 1 ), tf .argmax (y , axis = 1 )), tf .float32 ))
176+ saver = tf .train .Saver ()
183177
184178# initiate tf session
185179init = tf .global_variables_initializer ()
186180sess = tf .Session ()
187181sess .run (init )
188182
189183# initiate new training session
190- accuracy_hist = []
184+ cost_hist = []
191185iter = 0
192186
193187# training loop
194188while True :
189+ iter += 1
195190
196191 # get mini batch and train
197192 x_batch , y_batch = get_mini_batch (x_data , y_data , batch_size )
198- _ = sess .run ([optimizer ], feed_dict = {x : x_batch , y : y_batch })
199-
200- # other stuff
201- iter += 1
193+ _ = sess .run ([optimizer ], feed_dict = {x : x_batch , y : y_batch , drop : [dropout ]})
202194
203195 # plot and print
204- if iter % 50 == 0 :
205- error , acc = sess .run ([cost , accuracy ], feed_dict = {x : x_batch , y : y_batch })
206- accuracy_hist .append (acc )
207- pl .plot (accuracy_hist ); pl .pause (1e-99 )
208- print ('Cost: %.2f' % error )
209-
210- # generate new text by giving rnn something to start
211- starting_txt = 'Ive had the body moved to our little mor'
212- generate_new_text (txt = starting_txt , print_length = 1000 , new_line = 100 , dictionary = dictionary , reverse_dictionary = reverse_dictionary )
213-
196+ if iter % 100 == 0 :
197+ cost_ , acc = sess .run ([cost , accuracy ], feed_dict = {x : x_batch , y : y_batch , drop : [1.0 ]})
198+ cost_hist .append (cost_ )
199+ pl .cla ()
200+ pl .plot (cost_hist )
201+ pl .pause (1e-99 )
202+
203+ # generate new text by giving rnn something to start
204+ if iter % 500 == 0 or iter == 1 :
205+ starting_txt = 'Ive had the body moved to our little mor'
206+ generate_new_text (txt = starting_txt , print_length = 100 , new_line = 100 , dictionary = dictionary ,
207+ reverse_dictionary = reverse_dictionary )
208+
209+ # save
210+ # if iter % 500 == 0:
211+ # save_path = saver.save(sess, "model.ckpt")
0 commit comments