26
26
print ('Simple RNN model to detect a abba/0110 sequence' )
27
27
28
28
# create a fake dataset of symbols a,b:
29
+ rnn_neurons = 10
29
30
num_symbols = 2 # a,b
30
- data_size = 1024 * 8
31
+ data_size = 1024 * 4
31
32
seq_len = 4 # abba sequence to be detected only!
32
- num_layers = 4
33
+ num_layers = 2
33
34
rdata = np .random .randint (0 , num_symbols , data_size ) # 0=1, 1=b, for example
34
35
35
36
# turn it into 1-hot encoding:
52
53
53
54
54
55
# create model:
55
- model = nn .RNN (num_symbols , num_symbols , num_layers ) # see: http://pytorch.org/docs/nn.html#rnn
56
+ class Net (nn .Module ):
57
+ def __init__ (self ):
58
+ super (Net , self ).__init__ ()
59
+ self .rnn1 = nn .RNN (num_symbols , rnn_neurons , num_layers )
60
+ self .classifier1 = nn .Linear (rnn_neurons , num_symbols )
61
+
62
+ def forward (self , x , h ):
63
+ y , h = self .rnn1 (x ,h )
64
+ return self .classifier1 (y [seq_len - 1 ]), h
65
+
66
+ def init_hidden (self ):
67
+ return Variable (torch .zeros (num_layers , 1 , rnn_neurons ))
68
+
69
+
70
+ model = Net ()
56
71
criterion = nn .MSELoss ()
57
72
optimizer = optim .Adam (model .parameters (), lr = 0.001 )
58
73
@@ -81,7 +96,7 @@ def repackage_hidden(h):
81
96
82
97
def train ():
83
98
model .train ()
84
- hidden = Variable ( torch . zeros ( num_layers , 1 , num_symbols ) )
99
+ hidden = model . init_hidden ( )
85
100
86
101
for epoch in range (num_epochs ): # loop over the dataset multiple times
87
102
running_loss = 0.0
@@ -101,7 +116,7 @@ def train():
101
116
model .zero_grad ()
102
117
output , hidden = model (inputs , hidden )
103
118
104
- loss = criterion (output , labels )
119
+ loss = criterion (output , labels [ seq_len - 1 ] )
105
120
loss .backward ()
106
121
optimizer .step ()
107
122
@@ -121,8 +136,8 @@ def train():
121
136
122
137
def test ():
123
138
model .eval ()
124
- hidden = Variable ( torch . zeros ( num_layers , 1 , num_symbols ) )
125
- for i in range (0 , 32 ):
139
+ hidden = model . init_hidden ( )
140
+ for i in range (0 , 64 ):
126
141
inputs = torch .from_numpy ( data [i :i + seq_len ,:]).view (seq_len , 1 , num_symbols ).float ()
127
142
labels = torch .from_numpy (label [i :i + seq_len ,:]).view (seq_len , 1 , num_symbols ).float ()
128
143
inputs = Variable (inputs )
0 commit comments