|
1 | 1 | #! /usr/local/bin/python3
|
2 | 2 |
|
3 | 3 | # RNN example "abba" detector
|
4 |
| -# |
| 4 | +# |
| 5 | +# see this for a more complex example: |
| 6 | +# http://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html |
| 7 | +# |
5 | 8 | # E. Culurciello, April 2017
|
6 | 9 | #
|
7 | 10 |
|
|
15 | 18 | import torch
|
16 | 19 | import torch.nn as nn
|
17 | 20 | from torch.autograd import Variable
|
| 21 | +import torch.optim as optim |
| 22 | + |
np.set_printoptions(precision=2)
print('Simple RNN model to detect a abba/0110 sequence')

# create a fake dataset of symbols a,b:
num_symbols = 2   # a,b
data_size = 256
seq_len = 4       # abba sequence to be detected only!

rdata = np.random.randint(0, num_symbols, data_size)  # 0=a, 1=b

# turn it into 1-hot encoding (vectorized): symbol 1 -> (1,0), symbol 0 -> (0,1)
# NOTE(review): this column order is only a valid 1-hot for num_symbols == 2.
data = np.column_stack((rdata, 1 - rdata)).astype(np.float64)

print('dataset is:', data, 'with size:', data.shape)

# create labels: label[i] = (0,1) when rdata[i-3:i+1] == (0,1,1,0) i.e. "abba",
# otherwise (1,0); the first 3 positions stay all-zero (no full window yet).
label = np.zeros([data_size, num_symbols])
matches = ((rdata[:-3] == 0) & (rdata[1:-2] == 1) &
           (rdata[2:-1] == 1) & (rdata[3:] == 0))
label[3:, 0] = ~matches
label[3:, 1] = matches
count = int(matches.sum())

print('labels is:', label, 'total number of example sequences:', count)


# create model:
model = nn.RNN(num_symbols, num_symbols, 1)  # see: http://pytorch.org/docs/nn.html#rnn
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

# test model, see: http://pytorch.org/docs/nn.html#rnn
# inp = torch.zeros(seq_len, 1, num_symbols)
# inp[0,0,0]=1
# inp[1,0,1]=1
# inp[2,0,1]=1
# inp[3,0,0]=1
# h0 = torch.zeros(1, 1, num_symbols)
# print(inp, h0)
# output, hn = model( Variable(inp), Variable(h0))
# print('model test:', output,hn)


num_epochs = 4
def train():
    """Train the RNN for num_epochs passes over the dataset in seq_len windows.

    Uses the module-level model/criterion/optimizer and the data/label arrays.
    The hidden state is carried across windows so the network can see abba
    sequences that straddle a window boundary, but it is detached from the
    autograd graph each step so backward() only traverses the current window.
    """
    model.train()
    hidden = Variable(torch.zeros(1, 1, num_symbols))

    for epoch in range(num_epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        for i in range(0, data_size - seq_len, seq_len):
            # get inputs: (seq_len, batch=1, num_symbols) float tensors
            inputs = torch.from_numpy(data[i:i+seq_len, :]).view(seq_len, 1, num_symbols).float()
            labels = torch.from_numpy(label[i:i+seq_len, :]).view(seq_len, 1, num_symbols).float()

            # wrap them in Variable
            inputs, labels = Variable(inputs), Variable(labels)

            # detach the carried hidden state from the previous window's graph;
            # this replaces backward(retain_variables=True), which kept the whole
            # history alive and is not a valid argument in current PyTorch.
            hidden = Variable(hidden.data)

            # forward, backward, optimize
            optimizer.zero_grad()
            output, hidden = model(inputs, hidden)

            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            # print info / statistics:
            # float() works for both old 1-element and new 0-dim loss tensors,
            # unlike loss.data[0] which fails on 0-dim tensors.
            running_loss += float(loss.data)
            num_ave = 64
            if i % num_ave == 0:  # print every num_ave input positions
                # i advances by seq_len, so num_ave // seq_len batches were
                # accumulated since the last print — average over that count.
                print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / (num_ave // seq_len)))
                running_loss = 0.0

    print('Finished Training')
|
60 | 105 |
|
def test():
    """Evaluate: stream the whole dataset through the RNN in seq_len windows.

    Prints the input symbols, the target flags and the raw network outputs for
    every window, and announces 'RIGHT' when the window starts on a labeled
    abba match. The hidden state is carried across windows, as in training.
    """
    model.eval()
    hidden = Variable(torch.zeros(1, 1, num_symbols))
    for start in range(0, data_size - seq_len, seq_len):

        window = slice(start, start + seq_len)
        x = torch.from_numpy(data[window, :]).view(seq_len, 1, num_symbols).float()
        # built for parity with training; the printout below uses the numpy labels
        y = torch.from_numpy(label[window, :]).view(seq_len, 1, num_symbols).float()

        output, hidden = model(Variable(x), hidden)

        print('in:', data[window, 0], 'label:', label[window, 1], 'out:', output.data.numpy())
        if label[start, 1] > 0:
            print('RIGHT\n\n')
71 | 121 |
|
# train model, then evaluate it on the same (training) data:
print('\nTRAINING ---')
train()
print('\n\nTESTING ---')
test()
|
0 commit comments