Skip to content

Commit 4940768

Browse files
committed
added simple rnn demo
1 parent 0e1517b commit 4940768

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

rnn-example/rnn-simple.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#! /usr/local/bin/python3
2+
3+
# RNN example "abba" detector
4+
#
5+
# E. Culurciello, April 2017
6+
#
7+
8+
import sys
9+
import os
10+
import time
11+
import numpy as np
12+
import argparse
13+
from tqdm import tqdm
14+
from pathlib import Path
15+
import torch
16+
import torch.nn as nn
17+
from torch.autograd import Variable
18+
19+
# create a fake dataset of symbols a,b:
20+
data_size = 256
21+
seq_len = 4 # abba sequence to be detected only!
22+
data = np.random.randint(0, 2, data_size) # 0=1, 1=b, for example
23+
label = np.zeros(data_size, dtype=int)
24+
print('dataset is:', data, 'with length:', len(data))
25+
for i in range(3, data_size-1):
26+
if (data[i-3]==0 and data[i-2]==1 and data[i-1]==1 and data[i]==0):
27+
label[i] += 1
28+
29+
print('labels is:', label, 'total number of example sequences:', np.sum(label))
30+
31+
32+
# create model:
33+
model = nn.RNN(1,1,1)
34+
criterion = nn.L1Loss()
35+
36+
# test model:
37+
# inp = Variable(torch.randn(seq_len).view(seq_len,1,1))
38+
# h0 = Variable(torch.randn(seq_len).view(seq_len,1,1))
39+
# print(inp, h0)
40+
# output, hn = model(inp, h0)
41+
# print('model test:', output,hn)
42+
43+
44+
def train():
45+
model.train()
46+
hidden = Variable(torch.zeros(1,1,1))
47+
for i in tqdm(range(0, data_size-seq_len, seq_len)):
48+
X_batch = Variable(torch.from_numpy(data[i:i+seq_len]).view(seq_len,1,1).float())
49+
y_batch = Variable(torch.from_numpy(label[i:i+seq_len]).view(seq_len,1,1).float())
50+
model.zero_grad()
51+
output, hidden = model(X_batch, hidden)
52+
loss = criterion(output, y_batch)
53+
loss.backward(retain_variables=True)
54+
print('in/label/out:', data[i:i+seq_len], label[i:i+seq_len], output.data.view(1,4).numpy())
55+
# # print(X_batch, y_batch)
56+
if (data[i]==0 and data[i+1]==1 and data[i+2]==1 and data[i+3]==0):
57+
print('RIGHT')
58+
print(loss.data.numpy())
59+
60+
61+
def test():
62+
model.eval()
63+
hidden = Variable(torch.zeros(1,1,1))
64+
for i in range(0, data_size-seq_len, seq_len):
65+
X_batch = Variable(torch.from_numpy(data[i:i+seq_len]).view(seq_len,1,1).float())
66+
y_batch = Variable(torch.from_numpy(label[i:i+seq_len]).view(seq_len,1,1).float())
67+
output, hidden = model(X_batch, hidden)
68+
print('in/label/out:', data[i:i+seq_len], label[i:i+seq_len], output.data.view(1,4).numpy())
69+
if (data[i]==0 and data[i+1]==1 and data[i+2]==1 and data[i+3]==0):
70+
print('RIGHT')
71+
72+
# train model:
73+
train()
74+
test()

0 commit comments

Comments
 (0)