Skip to content

Commit 8131cd3

Browse files
committed
Merge pull request lisa-lab#83 from caglar/master
Fixed the random number generators.
2 parents f9108d3 + 29b167a commit 8131cd3

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

Diff for: code/lstm.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
'''
44
from collections import OrderedDict
55
import cPickle as pkl
6-
import random
76
import sys
87
import time
98

@@ -17,6 +16,9 @@
1716

1817
datasets = {'imdb': (imdb.load_data, imdb.prepare_data)}
1918

19+
# Set the random number generators' seeds for consistency
20+
SEED = 123
21+
numpy.random.seed(SEED)
2022

2123
def numpy_floatX(data):
2224
return numpy.asarray(data, dtype=config.floatX)
@@ -30,7 +32,7 @@ def get_minibatches_idx(n, minibatch_size, shuffle=False):
3032
idx_list = numpy.arange(n, dtype="int32")
3133

3234
if shuffle:
33-
random.shuffle(idx_list)
35+
numpy.random.shuffle(idx_list)
3436

3537
minibatches = []
3638
minibatch_start = 0
@@ -303,7 +305,7 @@ def rmsprop(lr, tparams, grads, x, mask, y, cost):
303305

304306

305307
def build_model(tparams, options):
306-
trng = RandomStreams(1234)
308+
trng = RandomStreams(SEED)
307309

308310
# Used for dropout.
309311
use_noise = theano.shared(numpy_floatX(0.))
@@ -401,7 +403,7 @@ def train_lstm(
401403
noise_std=0.,
402404
use_dropout=True, # if False slightly faster, but worst test error
403405
# This frequently need a bigger model.
404-
reload_model="", # Path to a saved model we want to start from.
406+
reload_model=None, # Path to a saved model we want to start from.
405407
test_size=-1, # If >0, we keep only this number of test example.
406408
):
407409

@@ -419,7 +421,7 @@ def train_lstm(
419421
# size example. So we must select a random selection of the
420422
# examples.
421423
idx = numpy.arange(len(test[0]))
422-
random.shuffle(idx)
424+
numpy.random.shuffle(idx)
423425
idx = idx[:test_size]
424426
test = ([test[0][n] for n in idx], [test[1][n] for n in idx])
425427

@@ -468,6 +470,7 @@ def train_lstm(
468470
print "%d train examples" % len(train[0])
469471
print "%d valid examples" % len(valid[0])
470472
print "%d test examples" % len(test[0])
473+
471474
history_errs = []
472475
best_p = None
473476
bad_count = 0
@@ -585,7 +588,6 @@ def train_lstm(
585588
if __name__ == '__main__':
586589
# See function train for all possible parameter and there definition.
587590
train_lstm(
588-
#reload_model="lstm_model.npz",
589591
max_epochs=100,
590592
test_size=500,
591593
)

0 commit comments

Comments
 (0)