|
| 1 | +#! /usr/local/bin/python3 |
| 2 | + |
| 3 | +# RNN example "abba" detector |
| 4 | +# |
| 5 | +# E. Culurciello, April 2017 |
| 6 | +# |
| 7 | + |
| 8 | +import sys |
| 9 | +import os |
| 10 | +import time |
| 11 | +import numpy as np |
| 12 | +import argparse |
| 13 | +from tqdm import tqdm |
| 14 | +from pathlib import Path |
| 15 | +import torch |
| 16 | +import torch.nn as nn |
| 17 | +from torch.autograd import Variable |
| 18 | + |
# Create a toy dataset over the two symbols a,b (encoded 0=a, 1=b).
# Position i is labelled 1 when the 4-symbol window ending at i spells
# "abba" (0,1,1,0); every other position is labelled 0.
data_size = 256
seq_len = 4 # abba sequence to be detected only!
data = np.random.randint(0, 2, data_size)  # 0=a, 1=b
label = np.zeros(data_size, dtype=int)
print('dataset is:', data, 'with length:', len(data))
# The range must reach data_size so the window ending at the LAST index
# is also checked (range(3, data_size-1) silently skipped it).
for i in range(3, data_size):
    if (data[i-3]==0 and data[i-2]==1 and data[i-1]==1 and data[i]==0):
        label[i] += 1

print('labels is:', label, 'total number of example sequences:', np.sum(label))
| 30 | + |
| 31 | + |
# Build the model: a single-layer Elman RNN with one input feature and
# one hidden unit, trained against a mean absolute error (L1) criterion.
model = nn.RNN(1, 1, 1)
criterion = nn.L1Loss()
| 42 | + |
| 43 | + |
def train():
    """One pass over the toy dataset, updating the RNN weights.

    Iterates the data in non-overlapping windows of seq_len symbols,
    feeds each window through the RNN (carrying the hidden state across
    windows) and takes an SGD step on the L1 loss against the labels.
    """
    model.train()
    # The original loop never stepped an optimizer, so the weights were
    # never updated; plain SGD is sufficient for this toy problem.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    hidden = Variable(torch.zeros(1, 1, 1))
    for i in tqdm(range(0, data_size - seq_len, seq_len)):
        X_batch = Variable(torch.from_numpy(data[i:i+seq_len]).view(seq_len, 1, 1).float())
        y_batch = Variable(torch.from_numpy(label[i:i+seq_len]).view(seq_len, 1, 1).float())
        model.zero_grad()
        output, hidden = model(X_batch, hidden)
        loss = criterion(output, y_batch)
        # Truncated BPTT: backprop within this window only. The original
        # backward(retain_variables=True) used a removed kwarg and kept
        # the whole history graph alive across every window.
        loss.backward()
        optimizer.step()
        # Detach the hidden state so the next window starts a fresh graph.
        hidden = Variable(hidden.data)
        print('in/label/out:', data[i:i+seq_len], label[i:i+seq_len], output.data.view(1, seq_len).numpy())
        if (data[i]==0 and data[i+1]==1 and data[i+2]==1 and data[i+3]==0):
            print('RIGHT')
        print(loss.data.numpy())
| 59 | + |
| 60 | + |
def test():
    """Run the RNN over the dataset and print its predictions.

    Mirrors train() but performs no weight updates: non-overlapping
    windows of seq_len symbols, hidden state carried across windows.
    """
    model.eval()
    hidden = Variable(torch.zeros(1, 1, 1))
    for i in range(0, data_size - seq_len, seq_len):
        # The labels are only printed for comparison, so no target
        # tensor is built here (the original created an unused y_batch).
        X_batch = Variable(torch.from_numpy(data[i:i+seq_len]).view(seq_len, 1, 1).float())
        output, hidden = model(X_batch, hidden)
        print('in/label/out:', data[i:i+seq_len], label[i:i+seq_len], output.data.view(1, seq_len).numpy())
        if (data[i]==0 and data[i+1]==1 and data[i+2]==1 and data[i+3]==0):
            print('RIGHT')
| 71 | + |
# Entry point: train on the random dataset, then report predictions.
# Guarded so importing this module does not kick off a training run.
if __name__ == "__main__":
    train()
    test()
0 commit comments