-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy path: test.py
102 lines (89 loc) · 3.36 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import sys
from time import time
import pickle
from os import makedirs
from os.path import join
import argparse
import torch
from torch import nn, sigmoid
from torch.utils.data import DataLoader
from model import UNet3D
from dataset import MRIDataset
from utils import Report, transfer_weights
"""
Run inference on test set.
This script also saves voxel-wise inference results and labels on numpy arrays
for future reuse (without running on gpu).
"""
# ---- Command-line interface --------------------------------------------
argv = sys.argv[1:]
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    prog='PROG',)
parser.add_argument('--data_dir',
                    type=str,
                    required=True,
                    help="path to images (should have subdirs: test)")
parser.add_argument('--model_path', '-m', type=str,
                    required=True,
                    help='path to a model snapshot')
# One value -> cubic volume; three values -> explicit (D, H, W).
parser.add_argument('--size', type=int, nargs='+', default=[128])
parser.add_argument('--dump_name',
                    type=str,
                    default='',
                    help='name dump file to avoid overwriting files.')
args = parser.parse_args(argv)

# ---- Model setup --------------------------------------------------------
net = UNet3D(1, 1, use_bias=True, inplanes=16)
transfer_weights(net, args.model_path)
net.cuda()
net.train(False)  # eval mode: freezes batch-norm / dropout behavior
# BUG FIX: the original `torch.no_grad()` as a bare statement is a no-op —
# it only constructs an (unused) context manager, so autograd stayed ON.
# Disable gradient tracking globally for this inference-only script.
torch.set_grad_enabled(False)

test_dir = join(args.data_dir, 'test')
batch_size = 1
def inference(target_dir):
    """Run the network over every test volume under `target_dir`.

    Prints aggregate and per-volume dice statistics, optionally pickles
    the per-volume predictions (when --dump_name is set), and returns
    flattened voxel-wise (preds, labels) numpy arrays for later reuse.
    """
    # A single --size scalar means a cubic crop.
    if len(args.size) == 1:
        volume_size = args.size * 3
    else:
        volume_size = args.size
    dataset = MRIDataset(target_dir,
                         volume_size,
                         sampling_mode='center',
                         deterministic=True)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)
    input_paths = dataset.inputs
    label_paths = dataset.labels

    reporter = Report()          # aggregated over the whole set
    preds_list, labels_list = [], []
    sum_hd = 0
    sum_sd = 0
    num_voxels_hd = []           # (foreground voxel count, hard dice) pairs

    for i, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.cuda()
        preds = sigmoid(net(inputs.detach()).detach())
        # Re-project the cropped prediction onto the full-resolution label.
        full_labels = dataset._load_full_label(label_paths[i])
        full_preds = dataset._project_full_label(input_paths[i],
                                                 preds.cpu())
        preds, labels = full_preds, full_labels
        reporter.feed(preds, labels)
        # Throwaway reporter gives per-volume dice scores.
        per_volume = Report()
        per_volume.feed(preds, labels)
        sum_hd += per_volume.hard_dice()
        sum_sd += per_volume.soft_dice()
        num_voxels_hd.append(
            (labels.view(-1).sum().item(), per_volume.hard_dice()))
        del per_volume
        preds_list.append(preds.cpu())
        labels_list.append(labels.cpu())
        del inputs, labels, preds

    print("Micro Averaged Dice {}, {}".format(sum_hd / len(dataloader),
                                              sum_sd / len(dataloader)))
    if len(args.dump_name):
        # dump preds for visualization
        pickle.dump([input_paths, preds_list, labels_list],
                    open('preds_dump_{}.pickle'.format(args.dump_name), 'wb'))
    print(reporter)
    print(reporter.stats())
    preds = torch.stack(preds_list).view(-1).numpy()
    labels = torch.stack(labels_list).view(-1).numpy().astype(int)
    print(num_voxels_hd)
    return preds, labels
preds, labels = inference(test_dir)