Skip to content

Commit c1da718

Browse files
authored
update
1 parent 10cb6f9 commit c1da718

File tree

1 file changed

+40
-42
lines changed

1 file changed

+40
-42
lines changed

RNN_predict_next_char.py

+40-42
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,26 @@
11
"""
22
Predicting next character with RNN using Tensorflow.
33
Suppose we decide to predict the 11th character using the preceding 10 characters.
4-
54
Say we have a short story excerpt:
65
"before the midnight..."
7-
86
The tensor fed into TensorFlow's RNN should have the shape [number of rows, 10]:
97
[
108
[b, e, f, o, r, e, , t, h, e],
119
[e, f, o, r, e, , t, h, e, ],
1210
[f, o, r, e, , t, h, e, , m],
1311
[o, r, e, , t, h, e, , m, i],
1412
[r, e, , t, h, e, , m, i, d], ... ]
15-
1613
The output should be like this (the character we want to predict):
1714
[
1815
[ ],
1916
[m],
2017
[i],
2118
[d],
2219
[n], ... ]
23-
2420
Roughly speaking, at each time step the RNN computes: output(t) = input_column(t) * w1 + hidden_state(t-1) * w2.
2521
So, before returning the final output (i.e. the 11th character) the net will repeat this process 10 times.
26-
2722
Of course, firstly, characters must be encoded into some kind of numerical representation.
2823
After feeding this data into rnn, tensorflow does all the work. The rest is self-explanatory (well, kind of).
29-
3024
Helpful source: https://medium.com/towards-data-science/lstm-by-example-using-tensorflow-feb0c1968537
3125
"""
3226

@@ -36,46 +30,45 @@
3630
from urllib.request import urlopen
3731

3832

39-
def download_data(url='https://ia601603.us.archive.org/3/items/CamusAlbertTheStranger/CamusAlbert-TheStranger_djvu.txt'):
33+
def download_data(
34+
url='https://ia601603.us.archive.org/3/items/CamusAlbertTheStranger/CamusAlbert-TheStranger_djvu.txt'):
4035
# Short story e.g.: http://textfiles.com/stories/aircon.txt
4136
# Longer story e.g.: https://ia601603.us.archive.org/3/items/CamusAlbertTheStranger/CamusAlbert-TheStranger_djvu.txt
4237
text = str(urlopen(url).read())
43-
text = text.replace('\\r', '')\
44-
.replace('\\n', '')\
45-
.replace('\\\'', '')\
46-
.replace('\\xe2\\x99\\xa6', '')\
47-
.replace('\\xe2\\x80\\x94', '')
38+
text = text.replace('\\r', '') \
39+
.replace('\\n', '') \
40+
.replace('\\\'', '') \
41+
.replace('\\xe2\\x99\\xa6', '') \
42+
.replace('\\xe2\\x80\\x94', '')
4843
return text
4944

5045

5146
def get_dicts(text):
52-
5347
"""
5448
Returns a tuple of three objects:
5549
dictionary is a dictionary that contains all unique characters in given text (keys) and their ids (values)
5650
reverse_dictionary is a dictionary that contains the characters' ids (as keys) and the characters (as values)
5751
chars is the text converted into a list, where each element is a single character.
5852
"""
59-
53+
6054
chars = list(text) # splits strings into chars and puts it into a list
6155
# chars = ''.join(char for char in text).split() # splits string into words and stores them in a list
6256

6357
dictionary, reverse_dictionary = {}, {}
6458
for id, char in enumerate(set(chars)):
6559
dictionary[char] = id
6660
reverse_dictionary[id] = char
67-
61+
6862
return dictionary, reverse_dictionary, chars
6963

7064

7165
def get_data(chars, dictionary, time_steps):
72-
7366
"""
7467
Returns data ready to be fed into neural net:
7568
x_data contains all sequences of characters (not the chars themselves, but their ids!). Each row corresponds to a single sequence.
7669
y_data contains the id of the next character for each sequence
7770
"""
78-
71+
7972
x_data = np.zeros(shape=(len(chars) - time_steps, time_steps))
8073
y_data = np.zeros(shape=(len(chars) - time_steps, len(set(chars))))
8174

@@ -86,15 +79,14 @@ def get_data(chars, dictionary, time_steps):
8679
return x_data, y_data
8780

8881

89-
def forward_prop(x, w, n_hidden):
90-
82+
def forward_prop(x, w, n_hidden, drop):
9183
"""
9284
RNN with tanh activation in hidden layers and softmax activation in the last layer.
9385
The number of elements in n_hidden corresponds to the number of layers; each element is the number of neurons in that layer.
9486
tf.contrib.rnn.static_rnn creates weights and biases automatically, so there is no need to initialize them manually.
9587
To inspect them, you can list all the tf variables with tf.get_collection('variables').
9688
"""
97-
89+
9890
# split the data into time_steps columns so the net can recur over them one column at a time
9991
x_split = tf.split(x, time_steps, 1)
10092

@@ -105,7 +97,7 @@ def forward_prop(x, w, n_hidden):
10597

10698
# create the net and add dropout
10799
lstm_cell = tf.contrib.rnn.MultiRNNCell(stacked_lstm_cells)
108-
lstm_cell_with_dropout = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=0.9)
100+
lstm_cell_with_dropout = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=drop[0])
109101

110102
# forward propagate
111103
outputs, state = tf.contrib.rnn.static_rnn(lstm_cell_with_dropout, x_split, dtype=tf.float32)
@@ -123,7 +115,6 @@ def get_mini_batch(x, y, batch_size):
123115

124116

125117
def generate_new_text(txt, print_length, new_line, dictionary, reverse_dictionary):
126-
127118
"""
128119
Generates text by predicting next character.
129120
Function arguments:
@@ -141,13 +132,13 @@ def generate_new_text(txt, print_length, new_line, dictionary, reverse_dictionar
141132
x_data_sample[:, id] = dictionary[char]
142133

143134
# print the text given as argument
144-
print(txt)
135+
print(txt, end='')
145136

146137
# predict next char, print, use predicted char to predict next and so on
147138
txt_length = 1
148139
for _ in range(print_length):
149-
next_char_id = np.argmax(sess.run([y_], feed_dict={x: x_data_sample,})[0], axis=1)
150-
next_char = reverse_dictionary[next_char_id[0]]
140+
next_char_id = np.random.choice(74, p=sess.run([y_], feed_dict={x: x_data_sample, drop: [1.0]})[0].ravel())
141+
next_char = reverse_dictionary[next_char_id]
151142
x_data_sample = np.delete(x_data_sample, 0, axis=1)
152143
x_data_sample = np.insert(x_data_sample, len(x_data_sample[0]), next_char_id, axis=1)
153144

@@ -164,6 +155,7 @@ def generate_new_text(txt, print_length, new_line, dictionary, reverse_dictionar
164155
batch_size = 250
165156
time_steps = 40 # size of sequence of chars
166157
learning_rate = 1e-3
158+
dropout = 0.9
167159

168160
# download and prepare data, initiate weights
169161
text = download_data()
@@ -173,41 +165,47 @@ def generate_new_text(txt, print_length, new_line, dictionary, reverse_dictionar
173165
# initiate tf placeholders
174166
x = tf.placeholder(tf.float32, [None, time_steps])
175167
y = tf.placeholder(tf.float32, [None, len(dictionary)])
168+
drop = tf.placeholder(tf.float32, [1])
176169

177170
# create other tf objects
178-
w = tf.Variable(tf.random_normal([n_hidden[-1], len(dictionary)]), dtype=tf.float32) # last layer weights
179-
logits, y_ = forward_prop(x, w, n_hidden)
171+
w = tf.Variable(tf.random_normal([n_hidden[-1], len(dictionary)]), dtype=tf.float32) # last layer weights
172+
logits, y_ = forward_prop(x, w, n_hidden, drop)
180173
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
181174
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
182175
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y_, axis=1), tf.argmax(y, axis=1)), tf.float32))
176+
saver = tf.train.Saver()
183177

184178
# initiate tf session
185179
init = tf.global_variables_initializer()
186180
sess = tf.Session()
187181
sess.run(init)
188182

189183
# initiate new training session
190-
accuracy_hist = []
184+
cost_hist = []
191185
iter = 0
192186

193187
# training loop
194188
while True:
189+
iter += 1
195190

196191
# get mini batch and train
197192
x_batch, y_batch = get_mini_batch(x_data, y_data, batch_size)
198-
_ = sess.run([optimizer], feed_dict={x: x_batch, y: y_batch})
199-
200-
# other stuff
201-
iter += 1
193+
_ = sess.run([optimizer], feed_dict={x: x_batch, y: y_batch, drop: [dropout]})
202194

203195
# plot and print
204-
if iter % 50 == 0:
205-
error, acc = sess.run([cost, accuracy], feed_dict={x: x_batch, y: y_batch})
206-
accuracy_hist.append(acc)
207-
pl.plot(accuracy_hist); pl.pause(1e-99)
208-
print('Cost: %.2f' % error)
209-
210-
# generate new text by giving rnn something to start
211-
starting_txt = 'Ive had the body moved to our little mor'
212-
generate_new_text(txt=starting_txt, print_length=1000, new_line=100, dictionary=dictionary, reverse_dictionary=reverse_dictionary)
213-
196+
if iter % 100 == 0:
197+
cost_, acc = sess.run([cost, accuracy], feed_dict={x: x_batch, y: y_batch, drop: [1.0]})
198+
cost_hist.append(cost_)
199+
pl.cla()
200+
pl.plot(cost_hist)
201+
pl.pause(1e-99)
202+
203+
# generate new text by giving rnn something to start
204+
if iter % 500 == 0 or iter == 1:
205+
starting_txt = 'Ive had the body moved to our little mor'
206+
generate_new_text(txt=starting_txt, print_length=100, new_line=100, dictionary=dictionary,
207+
reverse_dictionary=reverse_dictionary)
208+
209+
# save
210+
# if iter % 500 == 0:
211+
# save_path = saver.save(sess, "model.ckpt")

0 commit comments

Comments
 (0)