27
27
28
28
# create a fake dataset of symbols a,b:
29
29
num_symbols = 2 # a,b
30
- data_size = 256
30
+ data_size = 1024 * 8
31
31
seq_len = 4 # length of the 'abba' pattern, the only sequence to be detected
32
- num_layers = 3
32
+ num_layers = 4
33
33
rdata = np .random .randint (0 , num_symbols , data_size ) # 0=a, 1=b, for example
34
34
35
35
# turn it into 1-hot encoding:
36
36
data = np .empty ([data_size , num_symbols ])
37
37
for i in range (0 , data_size ):
38
- data [i ,:] = ( rdata [i ], not rdata [i ] )
38
+ data [i ,:] = ( rdata [i ], not rdata [i ] ) # only works for 2 symbols for now!
39
39
40
40
print ('dataset is:' , data , 'with size:' , data .shape )
41
41
42
42
# create labels:
43
43
label = np .zeros ([data_size , num_symbols ])
44
44
count = 0
45
- for i in range (3 , data_size ):
45
+ for i in range (seq_len , data_size ):
46
46
label [i ,:] = (1 ,0 )
47
47
if (rdata [i - 3 ]== 0 and rdata [i - 2 ]== 1 and rdata [i - 1 ]== 1 and rdata [i ]== 0 ):
48
48
label [i ,:] = (0 ,1 )
54
54
# create model:
55
55
model = nn .RNN (num_symbols , num_symbols , num_layers ) # see: http://pytorch.org/docs/nn.html#rnn
56
56
criterion = nn .MSELoss ()
57
- optimizer = optim .Adam (model .parameters (), lr = 0.005 )
57
+ optimizer = optim .Adam (model .parameters (), lr = 0.001 )
58
58
59
59
# test model, see: http://pytorch.org/docs/nn.html#rnn
60
60
# inp = torch.zeros(seq_len, 1, num_symbols)
@@ -108,6 +108,8 @@ def train():
108
108
# print info / statistics:
109
109
# print('in:', data[i:i+seq_len,0], 'label:', label[i:i+seq_len,1], 'out:', output.data.numpy())
110
110
# print(inputs, labels)
111
+ # input()
112
+
111
113
running_loss += loss .data [0 ]
112
114
num_ave = 64
113
115
if i % num_ave == 0 : # print every ave mini-batches
@@ -120,10 +122,11 @@ def train():
120
122
def test ():
121
123
model .eval ()
122
124
hidden = Variable (torch .zeros (num_layers , 1 , num_symbols ))
123
- for i in range (0 , data_size - seq_len , seq_len ):
125
+ for i in range (0 , 32 ):
124
126
inputs = torch .from_numpy ( data [i :i + seq_len ,:]).view (seq_len , 1 , num_symbols ).float ()
125
127
labels = torch .from_numpy (label [i :i + seq_len ,:]).view (seq_len , 1 , num_symbols ).float ()
126
- inputs = Variable (inputs )
128
+ inputs = Variable (inputs )
129
+ hidden = repackage_hidden (hidden )
127
130
output , hidden = model (inputs , hidden )
128
131
print ('in:' , data [i :i + seq_len ,0 ], 'label:' , label [i :i + seq_len ,1 ], 'out:' , output .data .numpy ())
129
132
0 commit comments