Skip to content

Commit 50d3092

Browse files
authored
Merge pull request #679 from nudles/master
update the rnn example
2 parents 536f7e4 + 0748d78 commit 50d3092

File tree

2 files changed

+5
-112
lines changed

2 files changed

+5
-112
lines changed

examples/rnn/sample.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

examples/rnn/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
8686
data = [self.char_to_idx[c] for c in self.raw_data]
8787
# seq_length + 1 for the data + label
8888
nsamples = len(data) // (1 + seq_length)
89-
data = data[0:300 * (1 + seq_length)]
89+
data = data[0: nsamples * (1 + seq_length)]
9090
data = np.asarray(data, dtype=np.int32)
9191
data = np.reshape(data, (-1, seq_length + 1))
9292
# shuffle all sequences
@@ -172,13 +172,13 @@ def sample(model, data, dev, nsamples=100, use_max=False):
172172
y = tensor.softmax(outputs[-1])
173173

174174

175-
def evaluate(model, data, batch_size, seq_length, dev):
175+
def evaluate(model, data, batch_size, seq_length, dev, inputs, labels):
176176
model.eval()
177177
val_loss = 0.0
178178
for b in range(data.num_test_batch):
179179
batch = data.val_dat[b * batch_size:(b + 1) * batch_size]
180180
inputs, labels = convert(batch, batch_size, seq_length, data.vocab_size,
181-
dev)
181+
dev, inputs, labels)
182182
model.reset_states(dev)
183183
y = model(inputs)
184184
loss = model.loss(y, labels)[0]
@@ -217,8 +217,8 @@ def train(data,
217217
print('\nEpoch %d, train loss is %f' %
218218
(epoch, train_loss / data.num_train_batch / seq_length))
219219

220-
# evaluate(model, data, batch_size, seq_length, cuda, inputs, labels)
221-
# sample(model, data, cuda)
220+
evaluate(model, data, batch_size, seq_length, cuda, inputs, labels)
221+
sample(model, data, cuda)
222222

223223

224224
if __name__ == '__main__':

0 commit comments

Comments
 (0)