Commit 5b72b21: fix bptt
1 parent d212d84

File tree: 1 file changed

rnn-example/rnn-simple.py (+23, -14)
@@ -2,7 +2,9 @@

 # RNN example "abba" detector
 #
-# see this for a more complex example:
+# another simple example:
+# https://github.com/pytorch/examples/tree/master/time_sequence_prediction
+# a more complex example:
 # http://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
 #
 # E. Culurciello, April 2017
@@ -27,6 +29,7 @@
 num_symbols = 2 # a,b
 data_size = 256
 seq_len = 4 # abba sequence to be detected only!
+num_layers = 3
 rdata = np.random.randint(0, num_symbols, data_size) # 0=a, 1=b, for example

 # turn it into 1-hot encoding:
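The one-hot encoding step itself sits outside this hunk's context window; a minimal sketch of the idea, assuming the script's numpy setup (the np.eye indexing trick is my illustration, not necessarily the repo's exact code):

import numpy as np

num_symbols, data_size = 2, 256
rdata = np.random.randint(0, num_symbols, data_size)  # 0=a, 1=b
data = np.eye(num_symbols)[rdata]                     # one-hot rows, shape (data_size, num_symbols)

This matches how the loops below slice data[i:i+seq_len,:] into (seq_len, 1, num_symbols) tensors.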
@@ -49,7 +52,7 @@


 # create model:
-model = nn.RNN(num_symbols, num_symbols, 1) # see: http://pytorch.org/docs/nn.html#rnn
+model = nn.RNN(num_symbols, num_symbols, num_layers) # see: http://pytorch.org/docs/nn.html#rnn
 criterion = nn.MSELoss()
 optimizer = optim.Adam(model.parameters(), lr=0.005)

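nn.RNN's third positional argument is num_layers, so this change stacks three recurrent layers; the initial hidden state must then have shape (num_layers, batch, hidden_size), which is exactly why the torch.zeros(1, 1, num_symbols) inits in train() and test() change below. A quick shape check, written against the current (post-Variable) PyTorch API:

import torch
import torch.nn as nn

num_symbols, num_layers, seq_len = 2, 3, 4
rnn = nn.RNN(num_symbols, num_symbols, num_layers)

x = torch.zeros(seq_len, 1, num_symbols)      # (seq_len, batch, input_size)
h0 = torch.zeros(num_layers, 1, num_symbols)  # (num_layers, batch, hidden_size)
out, hn = rnn(x, h0)
print(out.shape)  # torch.Size([4, 1, 2]): top-layer output at every time step
print(hn.shape)   # torch.Size([3, 1, 2]): final hidden state of each layer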
@@ -65,15 +68,22 @@
 # print('model test:', output,hn)


+def repackage_hidden(h):
+    """Wraps hidden states in new Variables, to detach them from their history."""
+    if type(h) == Variable:
+        return Variable(h.data)
+    else:
+        return tuple(repackage_hidden(v) for v in h)
+
+
 num_epochs = 4


 def train():
     model.train()
-    hidden = Variable(torch.zeros(1, 1, num_symbols))
+    hidden = Variable(torch.zeros(num_layers, 1, num_symbols))

     for epoch in range(num_epochs): # loop over the dataset multiple times
-
         running_loss = 0.0
         for i in range(0, data_size-seq_len, seq_len):
             # get inputs:
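repackage_hidden is the standard truncated-BPTT idiom (it also appears in the official word_language_model example): copying h.data into a fresh Variable cuts the autograd graph, so gradients stop flowing at the batch boundary. In current PyTorch, where Variable has been merged into Tensor, the same helper is usually written with detach(); a sketch:

import torch

def repackage_hidden(h):
    """Detach hidden state from its history so BPTT truncates at the batch boundary."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:  # LSTMs carry a (h, c) tuple; detach each element recursively
        return tuple(repackage_hidden(v) for v in h)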
@@ -83,12 +93,16 @@ def train():
             # wrap them in Variable
             inputs, labels = Variable(inputs), Variable(labels)

+            # Starting each batch, we detach the hidden state from how it was previously produced.
+            # If we didn't, the model would try backpropagating all the way to start of the dataset.
+            hidden = repackage_hidden(hidden)
+
             # forward, backward, optimize
-            optimizer.zero_grad()
+            model.zero_grad()
             output, hidden = model(inputs, hidden)

             loss = criterion(output, labels)
-            loss.backward(retain_variables=True)
+            loss.backward()
             optimizer.step()

             # print info / statistics:
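This hunk is the actual "fix bptt". Before it, the hidden state dragged its full autograd history along, so every backward() reached back through all earlier batches and the shared graph could only survive repeated passes via retain_variables=True (the old name for what later became retain_graph). Once the hidden state is detached at the top of each batch, every backward() sees a fresh, batch-sized graph and the flag can go. The switch from optimizer.zero_grad() to model.zero_grad() is behaviorally equivalent here, since the optimizer covers all of the model's parameters. A self-contained version of the loop against current PyTorch, with a toy next-symbol target standing in for the repo's abba labels (an assumption, not its data pipeline):

import torch
import torch.nn as nn

torch.manual_seed(0)
num_symbols, num_layers, seq_len, data_size = 2, 3, 4, 64
data = torch.eye(num_symbols)[torch.randint(0, num_symbols, (data_size,))]  # one-hot stream
target = torch.roll(data, -1, dims=0)  # toy target: predict the next symbol

model = nn.RNN(num_symbols, num_symbols, num_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

hidden = torch.zeros(num_layers, 1, num_symbols)
for i in range(0, data_size - seq_len, seq_len):
    inputs = data[i:i+seq_len].view(seq_len, 1, num_symbols)
    labels = target[i:i+seq_len].view(seq_len, 1, num_symbols)
    hidden = hidden.detach()        # cut the graph: backprop stops at this batch
    optimizer.zero_grad()
    output, hidden = model(inputs, hidden)
    loss = criterion(output, labels)
    loss.backward()                 # works without retain_graph because of the detach
    optimizer.step()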
@@ -105,19 +119,14 @@ def train():

 def test():
     model.eval()
-    hidden = Variable(torch.zeros(1, 1, num_symbols))
+    hidden = Variable(torch.zeros(num_layers, 1, num_symbols))
     for i in range(0, data_size-seq_len, seq_len):
-
         inputs = torch.from_numpy( data[i:i+seq_len,:]).view(seq_len, 1, num_symbols).float()
         labels = torch.from_numpy(label[i:i+seq_len,:]).view(seq_len, 1, num_symbols).float()
-
-        inputs = Variable(inputs)
-
+        inputs = Variable(inputs)
         output, hidden = model(inputs, hidden)
-
        print('in:', data[i:i+seq_len,0], 'label:', label[i:i+seq_len,1], 'out:', output.data.numpy())
-        if label[i,1]>0:
-            print('RIGHT\n\n')
+

 # train model:
 print('\nTRAINING ---')
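One thing the commit leaves alone: test() never detaches hidden, so under the old Variable API the eval loop still accumulates graph history across iterations. In current PyTorch the usual answer is to run inference under torch.no_grad(), which skips graph construction entirely; a sketch reusing the names from the loop above:

model.eval()
hidden = torch.zeros(num_layers, 1, num_symbols)
with torch.no_grad():  # no autograd graph is built, so there is nothing to detach
    for i in range(0, data_size - seq_len, seq_len):
        inputs = data[i:i+seq_len].view(seq_len, 1, num_symbols)
        output, hidden = model(inputs, hidden)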
