-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
68 lines (54 loc) · 1.93 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
from time import time, sleep
import numpy as np
import torch
from model import NTM
from data import CopyDataset
from utils import update_monitored_state
def parse_arguments():
    """Parse command-line options for evaluating a trained NTM copy model.

    Returns:
        argparse.Namespace with sequence_length, sequence_width,
        batch_size, load (checkpoint path, required) and monitor_state.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # All plain integer options share the same shape; register them in bulk.
    for flag, default in (('--sequence_length', 10),
                          ('--sequence_width', 8),
                          ('--batch_size', 1)):
        cli.add_argument(flag, type=int, default=default)
    cli.add_argument('--load', type=str, required=True)
    cli.add_argument('--monitor_state', action='store_true')
    return cli.parse_args()
def pprint(tensor, ndigits=2):
    """Print a batched 2-D tensor as comma-separated rows of rounded floats.

    Only the first element of the batch dimension is printed; each row of
    that element becomes one output line.

    Args:
        tensor: tensor whose first axis is the batch (only index 0 is shown).
        ndigits: number of decimal places to round to and display.
    """
    scale = 10 ** ndigits
    rows = torch.round(tensor[0] * scale) / scale
    fmt = '%0.' + str(ndigits) + 'f'
    for row in rows:
        print(', '.join(fmt % value for value in row))
def main():
    """Load a trained NTM checkpoint and print its copy-task predictions.

    Each sample drawn from CopyDataset is fed into the model one time step
    at a time, the prediction is read back step by step, and both the
    prediction and the target are pretty-printed for visual comparison.
    """
    torch.set_printoptions(precision=1, linewidth=240)
    args = parse_arguments()

    # Checkpoint layout: (_, _, state_dict, constructor args, _).
    _, _, model_state, init_args, _ = torch.load(args.load)
    # The last constructor argument toggles state monitoring; override it
    # with whatever was requested on the command line.
    init_args[-1] = args.monitor_state
    model = NTM(*init_args)
    model.load_state_dict(model_state)

    dataset = CopyDataset(
        args.sequence_length, args.sequence_width, 100, args.batch_size)
    for x, delimeter, y in dataset:
        if model.monitor_state:
            # Blank the visualisation before the episode starts.
            update_monitored_state(
                memory=None, read_head=None, write_head=None)
        model.reset_state(args.batch_size)
        if model.monitor_state:
            update_monitored_state(*model.get_memory_info())

        # Present the input sequence one step at a time, then the delimiter...
        seq_len = x.shape[1]
        for step in range(seq_len):
            model(x[:, step])
        model(delimeter)
        # ...and read the model's reproduction back with no further input.
        outputs = [model() for _ in range(seq_len)]
        pred = torch.stack(outputs, dim=1)

        print('pred')
        pprint(pred)
        print('y')
        pprint(y)
        print()
        sleep(1)
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()