
Commit 94011c6

Author: sfwang
Commit message: Add evaluation code.
Parent: 7127411

7 files changed: +170 -99 lines


lib/core/evaluate.py

Lines changed: 16 additions & 48 deletions

@@ -10,62 +10,30 @@

 import numpy as np

-from core.inference import get_max_preds
+from scipy.optimize import linear_sum_assignment

+from lib.core.inference import get_final_preds

-def calc_dists(preds, target, normalize):
+
+def calc_dists(preds, target):
     preds = preds.astype(np.float32)
     target = target.astype(np.float32)
-    dists = np.zeros((preds.shape[1], preds.shape[0]))
-    for n in range(preds.shape[0]):
-        for c in range(preds.shape[1]):
-            if target[n, c, 0] > 1 and target[n, c, 1] > 1:
-                normed_preds = preds[n, c, :] / normalize[n]
-                normed_targets = target[n, c, :] / normalize[n]
-                dists[c, n] = np.linalg.norm(normed_preds - normed_targets)
-            else:
-                dists[c, n] = -1
-    return dists

+    dists = np.sqrt(((preds.reshape(preds.shape[0], 1, preds.shape[1]) - \
+                      target.reshape(1, target.shape[0], target.shape[1])) ** 2) \
+                    .sum(axis=-1))
+
+    return dists

-def dist_acc(dists, thr=0.5):
-    ''' Return percentage below threshold while ignoring values with a -1 '''
-    dist_cal = np.not_equal(dists, -1)
-    num_dist_cal = dist_cal.sum()
-    if num_dist_cal > 0:
-        return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal
-    else:
-        return -1

+def calc_tp_fp_fn(preds, target, hit_thr=2):

-def accuracy(output, target, hm_type='gaussian', thr=0.5):
-    '''
-    Calculate accuracy according to PCK,
-    but uses ground truth heatmap rather than x,y locations
-    First value to be returned is average accuracy across 'idxs',
-    followed by individual accuracies
-    '''
-    idx = list(range(output.shape[1]))
-    norm = 1.0
-    if hm_type == 'gaussian':
-        pred, _ = get_max_preds(output)
-        target, _ = get_max_preds(target)
-        h = output.shape[2]
-        w = output.shape[3]
-        norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10
-    dists = calc_dists(pred, target, norm)
+    dists = calc_dists(preds, target)

-    acc = np.zeros((len(idx) + 1))
-    avg_acc = 0
-    cnt = 0
+    row_ind, col_ind = linear_sum_assignment(dists)

-    for i in range(len(idx)):
-        acc[i + 1] = dist_acc(dists[idx[i]])
-        if acc[i + 1] >= 0:
-            avg_acc = avg_acc + acc[i + 1]
-            cnt += 1
+    tp = np.sum(dists[row_ind, col_ind] <= hit_thr)
+    fp = dists.shape[0] - tp
+    fn = dists.shape[1] - tp

-    avg_acc = avg_acc / cnt if cnt != 0 else 0
-    if cnt != 0:
-        acc[0] = avg_acc
-    return acc, avg_acc, cnt, pred
+    return tp, fp, fn
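
The new metric works in three steps: calc_dists broadcasts an (N, 2) array of detections against an (M, 2) array of ground-truth sources into an (N, M) distance matrix, linear_sum_assignment (the Hungarian algorithm, which also accepts rectangular matrices) finds the one-to-one matching with minimum total distance, and each matched pair within hit_thr pixels counts as a true positive. A minimal sketch on made-up toy values:

import numpy as np
from scipy.optimize import linear_sum_assignment

preds = np.array([[10., 10.], [40., 42.], [70., 5.]])   # 3 detections
target = np.array([[11., 10.], [40., 40.]])             # 2 true sources

# same broadcast as calc_dists: (3, 1, 2) - (1, 2, 2) -> (3, 2) distances
dists = np.sqrt(((preds[:, None, :] - target[None, :, :]) ** 2).sum(axis=-1))

row_ind, col_ind = linear_sum_assignment(dists)   # optimal one-to-one matching
tp = np.sum(dists[row_ind, col_ind] <= 2)         # both matches within 2 px -> 2
fp = preds.shape[0] - tp                          # leftover detections      -> 1
fn = target.shape[0] - tp                         # missed sources           -> 0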

lib/core/function.py

Lines changed: 27 additions & 3 deletions

@@ -17,8 +17,8 @@
 import torch.autograd as autograd

 from lib.core.config import get_model_name
-# from lib.core.evaluate import accuracy
-# from core.inference import get_final_preds, get_final_integral_preds
+from lib.core.evaluate import calc_tp_fp_fn
+from lib.core.inference import get_final_preds
 # from utils.transforms import flip_back
 # from utils.vis import save_debug_images
 # from utils.vis_plain_keypoint import vis_mpii_keypoints
@@ -113,6 +113,9 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     # switch to evaluate mode
     model.eval()

+    all_preds = []
+    all_gts = []
+
     with torch.no_grad():
         end = time.time()
         for i, (input, target, meta) in enumerate(val_loader):
@@ -127,6 +130,13 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             losses.update(loss.item(), num_images)
             # _, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
             #                                  target.cpu().numpy())
+            preds = get_final_preds(output.detach().cpu().numpy())
+            all_preds.extend(preds)
+
+            sources = meta['sources'].clone().detach().cpu().numpy()
+            valid_source_nums = meta['valid_source_num'].clone().detach().cpu().numpy()
+            for j, gt in enumerate(sources):
+                all_gts.append(gt[:valid_source_nums[j], :])

             # acc.update(avg_acc, cnt)

@@ -168,12 +178,26 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,

         # prefix = '{}_{}'.format(os.path.join(output_dir, 'val'), i)

-        perf_indicator = losses.avg
+        # perf_indicator = losses.avg
+        total_tp = total_fp = total_fn = 0
+        for preds, target in zip(all_preds, all_gts):
+            tp, fp, fn = calc_tp_fp_fn(preds, target)
+            total_tp += tp
+            total_fp += fp
+            total_fn += fn
+
+        recall = total_tp / (total_tp + total_fn)
+        prec = total_tp / (total_tp + total_fp)
+
+        perf_indicator = 2 * prec * recall / (prec + recall)

         if writer_dict:
             writer = writer_dict['writer']
             global_steps = writer_dict['valid_global_steps']
             writer.add_scalar('valid_loss', losses.avg, global_steps)
+            writer.add_scalar('recall', recall, global_steps)
+            writer.add_scalar('precision', prec, global_steps)
+            writer.add_scalar('f_score', perf_indicator, global_steps)

             writer_dict['valid_global_steps'] = global_steps + 1
lib/core/inference.py

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao ([email protected])
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import cv2
+
+import numpy as np
+
+def get_final_preds(batch_heatmaps, score_thresh=0.5):
+    '''
+    get predictions from score maps
+    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
+    '''
+    assert isinstance(batch_heatmaps, np.ndarray), \
+        'batch_heatmaps should be numpy.ndarray'
+    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
+
+    assert batch_heatmaps.shape[1] == 1, 'batch_images must be single channel'
+
+    batch_size = batch_heatmaps.shape[0]
+    height = batch_heatmaps.shape[2]
+    width = batch_heatmaps.shape[3]
+
+    # local_max_heatmaps = np.zeros((batch_size, height, width, 1), dtype=float32)
+    batch_preds = []
+
+    for batch_idx in range(batch_size):
+        heatmaps = batch_heatmaps[batch_idx]
+        heatmaps = heatmaps.squeeze()
+
+        heatmaps_padded = cv2.copyMakeBorder(heatmaps, 1, 1, 1, 1, cv2.BORDER_REPLICATE)
+
+        local_max = np.ones(heatmaps.shape, dtype=bool)
+
+        for n_idx in range(9):
+            if n_idx == 4:
+                continue
+
+            neighbors = heatmaps_padded[
+                (n_idx % 3):(n_idx % 3 + height),
+                (n_idx // 3):(n_idx // 3 + width)
+            ]
+
+            local_max = np.logical_and(
+                local_max,
+                neighbors <= heatmaps
+            )
+
+        # heatmaps[np.logical_not(local_max)] = 0.0
+
+        # local_max_heatmaps[batch_idx] = heatmaps
+
+        rows, cols = np.where(np.logical_and(local_max, heatmaps > score_thresh))
+
+        preds = np.vstack((cols, rows)).transpose()
+
+        batch_preds.append(preds)
+
+    return batch_preds
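
get_final_preds performs non-maximum suppression by hand: it replicates the border, compares every pixel against its eight neighbours via the nine shifted views of the padded map, and keeps pixels that are local maxima with score above score_thresh, returned as (x, y) pairs. The same peak picking can be written more compactly with scipy.ndimage; the sketch below is an equivalent reformulation, not the commit's code, and assumes SciPy is available:

import numpy as np
from scipy.ndimage import maximum_filter

def find_peaks(heatmap, score_thresh=0.5):
    # mode='nearest' mirrors cv2.BORDER_REPLICATE; a pixel equal to the
    # max of its 3x3 window is >= all eight neighbours, as above
    local_max = heatmap == maximum_filter(heatmap, size=3, mode='nearest')
    rows, cols = np.where(local_max & (heatmap > score_thresh))
    return np.stack((cols, rows), axis=1)   # (x, y) pairs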

lib/dataset/SimulatedDataset.py

Lines changed: 20 additions & 17 deletions

@@ -7,12 +7,12 @@

 from torch.utils.data import Dataset

-from lib.utils.utils import get_source, evalpotential, generate_heatmaps
+from lib.utils.utils import get_source, get_mesh, evalpotential, generate_heatmaps

 logger = logging.getLogger(__name__)

 class SimulatedDataset(Dataset):
-    def __init__(self, cfg, root, image_set, is_train, transform):
+    def __init__(self, cfg, root, image_set, is_train, transform, n_sources=30):
         self.cfg = cfg
         self.is_train = is_train

@@ -23,9 +23,13 @@ def __init__(self, cfg, root, image_set, is_train, transform):

         self.transform = transform

+        self.n_sources = n_sources  # only used at test time
+
         self.patch_width = cfg.MODEL.IMAGE_SIZE[0]
         self.patch_height = cfg.MODEL.IMAGE_SIZE[1]

+        self.mesh = get_mesh(self.patch_width, self.patch_height)
+
         self.db = self._get_db()
         self.db_length = len(self.db)

@@ -34,7 +38,7 @@ def __init__(self, cfg, root, image_set, is_train, transform):
     def __getitem__(self, idx):
         the_db = copy.deepcopy(self.db[idx])

-        image = evalpotential(the_db['mesh'], the_db['sources'])
+        image = evalpotential(self.mesh, the_db['sources'])

         image = image.reshape((self.patch_height, self.patch_width)) + \
             self.cfg.MODEL.VAR_NOISE*np.random.randn(self.patch_height, self.patch_width)
@@ -53,31 +57,30 @@ def __getitem__(self, idx):
             self.cfg.MODEL.OUTPUT_SIZE[1]
         )

+        valid_source_num = unnormalized_sources.shape[0]
+
+        if valid_source_num < 64:
+            unnormalized_sources = np.concatenate(
+                (unnormalized_sources,
+                 -np.ones((64 - valid_source_num, 2))),
+                axis=0
+            )
+
         meta = {
-            'sources': unnormalized_sources
+            'sources': unnormalized_sources,
+            'valid_source_num': valid_source_num
         }

         return image, heatmap_target, meta

     def _get_db(self):
         gt_db = []
         self.db_length = int(self.cfg.TRAIN.NUM_SAMPLES) if self.is_train else int(self.cfg.TEST.NUM_SAMPLES)
-        n_sources = 30
         for i in range(self.db_length):
-            mesh, sources = get_source(self.patch_width, self.patch_height,
-                self.cfg.MODEL.DEPTH, n_sources, self.cfg.MODEL.VAR_NOISE)
-
-            # image.view(-1).repeat(3).view(self.patch_width, self.patch_height, 3)
-            # image = np.expand_dims(image, axis=-1).repeat(3, axis=-1)
-
-            # image = self.transform(image)
-
-            # sources = sources.transpose()
-
-            # heatmap_target = generate_heatmaps(sources, self.patch_height, self.patch_width)
+            n_sources = np.random.randint(1, 65) if self.is_train else self.n_sources
+            sources = get_source(self.cfg.MODEL.DEPTH, n_sources, self.cfg.MODEL.VAR_NOISE)

             gt_db.append({
-                'mesh': mesh,
                 'sources': sources
             })
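
The fixed-size padding exists because DataLoader's default collate function stacks each sample's 'sources' array into one batch tensor, which requires a common shape; the cap of 64 matches the training-time draw np.random.randint(1, 65). valid_source_num records how many rows are real so consumers can strip the -1 padding again, which is exactly what validate does. A sketch of that round trip, with shapes assumed from the diff:

sources = meta['sources'].numpy()          # (batch_size, 64, 2), padded with -1
valid = meta['valid_source_num'].numpy()   # (batch_size,)
per_sample_gts = [s[:n] for s, n in zip(sources, valid)]   # padding removed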

lib/utils/utils.py

Lines changed: 6 additions & 2 deletions

@@ -120,7 +120,7 @@ def evalpotential(sites_locations, sources):
     # recordings[i] = sum(sources[0,:]/dis_sources);
     return recordings

-def get_source(x_mesh, y_mesh, depth, n_sources, var_noise):
+def get_mesh(x_mesh, y_mesh):

     d = x_mesh*y_mesh
     x = np.linspace(0, 1, x_mesh)
@@ -129,12 +129,16 @@ def get_source(x_mesh, y_mesh, depth, n_sources, var_noise):
     Xsim = np.reshape(Xsim, [d])
     Ysim = np.reshape(Ysim, [d])
     mesh = np.array([Xsim,Ysim,np.zeros(d)]);
+    return mesh
+
+def get_source(depth, n_sources, var_noise):
+
     sources = np.random.rand(4, n_sources);
     sources[3, :] = depth
     sources[0, :] = 2*np.floor(2*sources[0, :])-1;
     # image = evalpotential(mesh, sources);
     # image = image.reshape((y_mesh, x_mesh)) + var_noise*np.random.randn(y_mesh, x_mesh)
-    return mesh, sources
+    return sources

 def generate_heatmaps(keypoints, im_height, im_width):
     heatmaps = np.zeros((1, int(im_height), int(im_width)), dtype=np.float32)
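
Splitting get_mesh out of get_source lets SimulatedDataset build the electrode mesh once in __init__ instead of once per sample; get_source now returns only the random 4 x n_sources matrix (var_noise is no longer used inside it and survives only in the commented-out lines). A usage sketch; the reading of the four rows is an assumption based on how get_source fills them and on the commented-out line in evalpotential:

mesh = get_mesh(64, 64)        # (3, 4096): flattened x/y grid on [0, 1], z = 0
sources = get_source(depth=0.1, n_sources=30, var_noise=0.01)
# sources[0, :]   -> random +/-1 sign/amplitude per source
# sources[1:3, :] -> uniform values in [0, 1), presumably source positions
# sources[3, :]   -> the constant depth
image = evalpotential(mesh, sources).reshape((64, 64))   # assuming a 64x64 patch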

source_detection/train.py

Lines changed: 2 additions & 2 deletions

@@ -163,7 +163,7 @@ def main():
         pin_memory=True
     )

-    best_perf = 1000.0
+    best_perf = 0.0
     best_model = False
     for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
         lr_scheduler.step()
@@ -178,7 +178,7 @@ def main():
                                 criterion, final_output_dir, tb_log_dir,
                                 writer_dict)

-        if perf_indicator < best_perf:
+        if perf_indicator > best_perf:
             best_perf = perf_indicator
             best_model = True
         else:
