Skip to content

Commit d212d84

Browse files
committed
updated
1 parent 4940768 commit d212d84

File tree

1 file changed

+86
-34
lines changed

1 file changed

+86
-34
lines changed

rnn-example/rnn-simple.py

+86-34
Original file line numberDiff line numberDiff line change
#! /usr/local/bin/python3

# RNN example "abba" detector
#
# see this for a more complex example:
# http://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
#
# E. Culurciello, April 2017
#
@@ -15,60 +18,109 @@
1518
import torch
1619
import torch.nn as nn
1720
from torch.autograd import Variable
21+
import torch.optim as optim
22+
23+
# compact numpy printing (two decimal places):
np.set_printoptions(precision=2)
print('Simple RNN model to detect a abba/0110 sequence')

# ---- fake dataset: a random stream over a 2-symbol alphabet ----
num_symbols = 2  # alphabet size: 'a' and 'b'
data_size = 256  # length of the random symbol stream
seq_len = 4      # we only detect the 4-symbol pattern "abba"
rdata = np.random.randint(0, num_symbols, data_size)  # 0 = 'a', 1 = 'b'

# 2-d encoding of the stream: symbol s -> (s, not s),
# i.e. 0 -> (0, 1) and 1 -> (1, 0):
data = np.empty([data_size, num_symbols])
for idx, sym in enumerate(rdata):
    data[idx, :] = (sym, not sym)

print('dataset is:', data, 'with size:', data.shape)

# ---- labels: row i is (0,1) when positions i-3..i spell 0,1,1,0 ("abba"),
# (1,0) otherwise; the first three rows stay all-zero (no full window yet) ----
label = np.zeros([data_size, num_symbols])
count = 0
for idx in range(3, data_size):
    label[idx, :] = (1, 0)  # default: no "abba" ending at this position
    if (rdata[idx - 3], rdata[idx - 2], rdata[idx - 1], rdata[idx]) == (0, 1, 1, 0):
        label[idx, :] = (0, 1)  # "abba" ends here
        count += 1

print('labels is:', label, 'total number of example sequences:', count)


# ---- model, loss and optimizer (see: http://pytorch.org/docs/nn.html#rnn) ----
model = nn.RNN(num_symbols, num_symbols, 1)  # single-layer RNN, 2-d input and hidden
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

# manual smoke test of the model, kept for reference
# (see: http://pytorch.org/docs/nn.html#rnn):
# inp = torch.zeros(seq_len, 1, num_symbols)
# inp[0,0,0]=1
# inp[1,0,1]=1
# inp[2,0,1]=1
# inp[3,0,0]=1
# h0 = torch.zeros(1, 1, num_symbols)
# print(inp, h0)
# output, hn = model( Variable(inp), Variable(h0))
# print('model test:', output,hn)


num_epochs = 4
69+
70+
4471
def train():
    """Train the RNN on the module-level (data, label) stream.

    Runs num_epochs passes over the dataset, consuming it in
    non-overlapping chunks of seq_len symbols.  The hidden state is
    carried across chunks so the network can track context through the
    whole stream.  Uses the module-level model/criterion/optimizer.
    """
    model.train()
    # initial hidden state, shape (num_layers, batch, hidden_size):
    hidden = Variable(torch.zeros(1, 1, num_symbols))

    # loop-invariant: report the running loss every num_ave input symbols
    # (hoisted out of the inner loop — it was re-bound on every iteration)
    num_ave = 64

    for epoch in range(num_epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        for i in range(0, data_size - seq_len, seq_len):
            # get inputs/targets, shaped (seq_len, batch=1, num_symbols):
            inputs = torch.from_numpy( data[i:i+seq_len,:]).view(seq_len, 1, num_symbols).float()
            labels = torch.from_numpy(label[i:i+seq_len,:]).view(seq_len, 1, num_symbols).float()

            # wrap them in Variable (old-style autograd API)
            inputs, labels = Variable(inputs), Variable(labels)

            # forward, backward, optimize
            optimizer.zero_grad()
            output, hidden = model(inputs, hidden)

            loss = criterion(output, labels)
            # NOTE(review): retain_variables appears needed because `hidden`
            # keeps the graph alive across chunks and is backed through again;
            # consider detaching `hidden` per chunk instead — confirm intent.
            loss.backward(retain_variables=True)
            optimizer.step()

            # print info / statistics; one loss is accumulated per chunk of
            # seq_len symbols, so running_loss / num_ave is a per-symbol figure
            running_loss += loss.data[0]
            if i % num_ave == 0:  # print every num_ave input symbols
                print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / num_ave))
                running_loss = 0.0

    print('Finished Training')
59104

60105

61106
def test():
    """Run the trained RNN over the whole stream and print its outputs.

    Walks the dataset in non-overlapping seq_len chunks, carrying the
    hidden state forward, printing input/label/output for each chunk and
    flagging chunks whose first position is labelled as the end of an
    "abba" sequence.
    """
    model.eval()
    # initial hidden state, shape (num_layers, batch, hidden_size):
    hidden = Variable(torch.zeros(1, 1, num_symbols))
    for i in range(0, data_size - seq_len, seq_len):

        # one chunk of inputs, shaped (seq_len, batch=1, num_symbols);
        # (a matching `labels` tensor was built here before but never used —
        # removed, labels are printed straight from the numpy array below)
        inputs = torch.from_numpy( data[i:i+seq_len,:]).view(seq_len, 1, num_symbols).float()

        inputs = Variable(inputs)

        output, hidden = model(inputs, hidden)

        print('in:', data[i:i+seq_len,0], 'label:', label[i:i+seq_len,1], 'out:', output.data.numpy())
        if label[i,1]>0:
            print('RIGHT\n\n')
71121

72122
# run training, then evaluate the trained network on the same stream:
print('\nTRAINING ---')
train()
print('\n\nTESTING ---')
test()

0 commit comments

Comments
 (0)