
Commit 03bf7a8
Author: sfwang
Add scripts and config files for running on real data.
1 parent: ba0ada5

File tree: 11 files changed, +469 -26 lines
@@ -0,0 +1,77 @@
+GPUS: '0'
+DATA_DIR: ''
+OUTPUT_DIR: 'output'
+LOG_DIR: 'log'
+WORKERS: 0
+PRINT_FREQ: 1
+CUDNN:
+  BENCHMARK: False
+  DETERMINISTIC: True
+  ENABLED: True
+DATASET:
+  DATASET: 'real_dataset'
+  ROOT: None
+  TEST_SET: ['srep31332-s1.mat']
+  TRAIN_SET: None
+  X_MIN: 20
+  Y_MIN: 190
+  X_MAX: 148
+  Y_MAX: 318
+  # FLIP: true
+  # ROT_FACTOR: 30
+  # SCALE_FACTOR: 0.25
+MODEL:
+  NAME: neuron_resnet
+  PRETRAINED: 'models/pytorch/imagenet/resnet18-5c106cde.pth'
+  IMAGE_SIZE:
+  - 128
+  - 128
+  EXTRA:
+    TARGET_TYPE: gaussian
+    SIGMA: 2
+    HEATMAP_SIZE:
+    - 64
+    - 64
+    FINAL_CONV_KERNEL: 1
+    DECONV_WITH_BIAS: false
+    NUM_DECONV_LAYERS: 4
+    NUM_DECONV_FILTERS:
+    - 256
+    - 256
+    - 256
+    - 256
+    NUM_DECONV_KERNELS:
+    - 4
+    - 4
+    - 4
+    - 4
+    NUM_LAYERS: 18
+LOSS:
+  USE_TARGET_WEIGHT: False
+TRAIN:
+  BATCH_SIZE: 32
+  SHUFFLE: true
+  BEGIN_EPOCH: 0
+  END_EPOCH: 140
+  RESUME: false
+  OPTIMIZER: adam
+  LR: 0.001
+  LR_FACTOR: 0.1
+  LR_STEP:
+  - 90
+  - 120
+  WD: 0.0001
+  GAMMA1: 0.99
+  GAMMA2: 0.0
+  MOMENTUM: 0.9
+  NESTEROV: false
+TEST:
+  BATCH_SIZE: 1
+  FLIP_TEST: false
+  MODEL_FILE: 'output/simulation/neuron_resnet_18/128x128_d256x3_adam_lr1e-3/model_best.pth.tar'
+DEBUG:
+  DEBUG: false
+  SAVE_BATCH_IMAGES_GT: true
+  SAVE_BATCH_IMAGES_PRED: true
+  SAVE_HEATMAPS_GT: true
+  SAVE_HEATMAPS_PRED: true
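
Note: DATASET.DATASET selects the real_dataset loader registered later in this commit, TEST.MODEL_FILE points at a checkpoint trained on the simulated data, and X_MIN/Y_MIN/X_MAX/Y_MAX define the crop window applied to each recorded frame. As a rough illustration of how an experiment YAML like this is merged over the defaults in lib/core/config.py (a generic sketch with a hypothetical file path and illustrative defaults; the repository's actual loader may differ):

import yaml

def merge_into(defaults, overrides):
    # Recursively overwrite default values with keys present in the experiment YAML.
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(defaults.get(key), dict):
            merge_into(defaults[key], value)
        else:
            defaults[key] = value
    return defaults

# Illustrative defaults only; the real defaults live in lib/core/config.py.
defaults = {'WORKERS': 4, 'DATASET': {'DATASET': 'simulation', 'X_MIN': 250, 'Y_MIN': 50}}

with open('experiments/real_data.yaml') as f:   # hypothetical path for the file above
    cfg = merge_into(defaults, yaml.safe_load(f))
print(cfg['DATASET']['DATASET'], cfg['DATASET']['X_MIN'])   # real_dataset 20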

lib/core/config.py  (+5)

@@ -76,6 +76,11 @@
 config.DATASET.TEST_SET = 'valid'
 config.DATASET.DATA_FORMAT = 'jpg'

+config.DATASET.X_MIN = 250
+config.DATASET.Y_MIN = 50
+config.DATASET.X_MAX = 378
+config.DATASET.Y_MAX = 178
+
 # training data augmentation
 config.DATASET.FLIP = True
 config.DATASET.SCALE_FACTOR = 0.25
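
The defaults added here describe a 128 × 128 crop window (378 − 250 = 128 and 178 − 50 = 128), matching MODEL.IMAGE_SIZE; the real-data YAML above overrides it to x in [20, 148) and y in [190, 318), which is again 128 × 128.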

lib/core/evaluate.py  (+4 -4)

@@ -16,7 +16,7 @@


 def calc_dists(preds, target):
-    preds = preds.astype(np.float32)
+    preds = preds.astype(np.float32)[:, :2]
     target = target.astype(np.float32)

     dists = np.sqrt(((preds.reshape(preds.shape[0], 1, preds.shape[1]) - \
@@ -32,8 +32,8 @@ def calc_tp_fp_fn(preds, target, hit_thr=2):

     row_ind, col_ind = linear_sum_assignment(dists)

-    tp = np.sum(dists[row_ind, col_ind] <= hit_thr)
-    fp = dists.shape[0] - tp
-    fn = dists.shape[1] - tp
+    tp = dists[row_ind, col_ind] <= hit_thr
+    fp = np.logical_not(tp)
+    fn = dists.shape[1] - np.sum(tp)

     return tp, fp, fn
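
calc_tp_fp_fn now returns per-match boolean masks for tp and fp instead of scalar counts (callers sum them, as in lib/core/function.py below), which the new visualization uses to color individual detections. A toy illustration of the matching semantics, assuming preds and target are plain (x, y) coordinate arrays as consumed by calc_dists:

import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy example: three predicted peaks vs. two ground-truth sources, as (x, y).
preds = np.array([[10.0, 12.0], [40.0, 41.0], [70.0, 5.0]])
target = np.array([[11.0, 12.0], [60.0, 60.0]])

# Pairwise Euclidean distances, shape (num_preds, num_targets).
dists = np.sqrt(((preds[:, None, :] - target[None, :, :]) ** 2).sum(axis=-1))

# Hungarian matching assigns each target to at most one prediction.
row_ind, col_ind = linear_sum_assignment(dists)

hit_thr = 2
tp = dists[row_ind, col_ind] <= hit_thr   # boolean mask over matched pairs
fp = np.logical_not(tp)                   # matched pairs that miss the threshold
fn = dists.shape[1] - np.sum(tp)          # targets without a close enough match
print(tp, fp, fn)                         # [ True False] [False  True] 1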

lib/core/function.py  (+76 -6)

@@ -20,7 +20,7 @@
 from lib.core.evaluate import calc_tp_fp_fn
 from lib.core.inference import get_final_preds
 # from utils.transforms import flip_back
-# from utils.vis import save_debug_images
+from lib.utils.vis import vis_preds
 # from utils.vis_plain_keypoint import vis_mpii_keypoints
 # from utils.integral import softmax_integral_tensor

@@ -104,7 +104,7 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
             # prefix)


-def validate(config, val_loader, val_dataset, model, criterion, output_dir,
+def validate(config, val_loader, model, criterion, output_dir,
              tb_log_dir, writer_dict=None):
     batch_time = AverageMeter()
     losses = AverageMeter()
@@ -167,8 +167,16 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
                    heatmap_pred[heatmap_pred < 0.0] = 0
                    heatmap_pred[heatmap_pred > 1.0] = 1.0

-                    writer.add_image('input_recording', input_image, global_steps,
-                                     dataformats='CHW')
+                    input_image = (input_image * 255).astype(np.uint8)
+                    input_image = np.transpose(input_image, (1, 2, 0))
+                    pred = preds[idx]
+                    gt = sources[idx][:valid_source_nums[idx], :]
+
+                    tp, _, _ = calc_tp_fp_fn(pred, gt)
+                    final_preds = vis_preds(input_image, pred, tp)
+
+                    writer.add_image('final_preds', final_preds, global_steps,
+                                     dataformats='HWC')
                    writer.add_image('heatmap_target', heatmap_target, global_steps,
                                     dataformats='CHW')
                    writer.add_image('heatmap_pred', heatmap_pred, global_steps,
@@ -182,8 +190,8 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     total_tp = total_fp = total_fn = 0
     for preds, target in zip(all_preds, all_gts):
         tp, fp, fn = calc_tp_fp_fn(preds, target)
-        total_tp += tp
-        total_fp += fp
+        total_tp += np.sum(tp)
+        total_fp += np.sum(fp)
         total_fn += fn

     recall = total_tp / (total_tp + total_fn)
@@ -204,6 +212,68 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     return perf_indicator


+def test(config, test_loader, model, output_dir, tb_log_dir,
+         writer_dict=None):
+    batch_time = AverageMeter()
+
+    # switch to evaluate mode
+    model.eval()
+
+    all_preds = []
+
+    with torch.no_grad():
+        end = time.time()
+        for i, input in enumerate(test_loader):
+            # compute output
+            output = model(input)
+
+            num_images = input.size(0)
+
+            preds = get_final_preds(output.detach().cpu().numpy())
+            all_preds.extend(preds)
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            if i % config.PRINT_FREQ == 0:
+                msg = 'Test: [{0}/{1}]\t' \
+                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
+                          i, len(test_loader), batch_time=batch_time)
+                logger.info(msg)
+
+                if writer_dict:
+                    writer = writer_dict['writer']
+                    global_steps = writer_dict['vis_global_steps']
+
+                    idx = np.random.randint(0, num_images)
+
+                    input_image = input.detach().cpu().numpy()[idx]
+                    min_val = input_image.min()
+                    max_val = input_image.max()
+                    input_image = (input_image - min_val) / (max_val - min_val)
+                    heatmap_pred = output.detach().cpu().numpy()[idx]
+                    heatmap_pred[heatmap_pred < 0.0] = 0
+                    heatmap_pred[heatmap_pred > 1.0] = 1.0
+
+                    input_image = (input_image * 255).astype(np.uint8)
+                    input_image = np.transpose(input_image, (1, 2, 0))
+                    pred = preds[idx]
+                    tp = np.ones(pred.shape[0], dtype=bool)
+                    final_preds = vis_preds(input_image, pred, tp)
+
+                    writer.add_image('final_preds', final_preds, global_steps,
+                                     dataformats='HWC')
+                    writer.add_image('input_recording', input_image, global_steps,
+                                     dataformats='HWC')
+                    writer.add_image('heatmap_pred', heatmap_pred, global_steps,
+                                     dataformats='CHW')
+
+                    writer_dict['vis_global_steps'] = global_steps + 1
+
+            # prefix = '{}_{}'.format(os.path.join(output_dir, 'val'), i)
+
+
 # markdown format output
 def _print_name_value(name_value, full_arch_name):
     names = name_value.keys()
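
The new test() entry point runs the trained network over unlabeled real recordings, collects per-frame detections from get_final_preds, and, when a writer_dict is supplied, logs the detection overlay and predicted heatmap to TensorBoard. A minimal wiring sketch, assuming a torch.utils.tensorboard SummaryWriter and an already-built cfg, model, and test_dataset (the surrounding names are placeholders, not the repository's test script):

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Hypothetical wiring; cfg, model and test_dataset are assumed to exist already.
writer_dict = {
    'writer': SummaryWriter(log_dir='log/real_dataset'),   # placeholder log dir
    'vis_global_steps': 0,                                 # counter read and advanced inside test()
}
test_loader = DataLoader(test_dataset, batch_size=cfg.TEST.BATCH_SIZE,
                         shuffle=False, num_workers=cfg.WORKERS)
test(cfg, test_loader, model, 'output', 'log', writer_dict=writer_dict)
writer_dict['writer'].close()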

lib/core/inference.py  (+3 -1)

@@ -59,7 +59,9 @@ def get_final_preds(batch_heatmaps, score_thresh=0.5):

        rows, cols = np.where(np.logical_and(local_max, heatmaps > score_thresh))

-        preds = np.vstack((cols, rows)).transpose()
+        scores = heatmaps[rows, cols]
+
+        preds = np.vstack((cols, rows, scores)).transpose()

        batch_preds.append(preds)
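
Each prediction row now carries its heatmap score alongside the (x, y) peak location, so downstream code can rank or re-threshold detections later. A self-contained sketch of the peak-picking idea, using scipy.ndimage.maximum_filter for the local-maximum test (the repository's own local_max computation may differ):

import numpy as np
from scipy.ndimage import maximum_filter

def peaks_with_scores(heatmap, score_thresh=0.5):
    # A pixel is a peak if it equals the maximum of its 3x3 neighbourhood
    # and its score clears the threshold.
    local_max = heatmap == maximum_filter(heatmap, size=3)
    rows, cols = np.where(np.logical_and(local_max, heatmap > score_thresh))
    scores = heatmap[rows, cols]
    return np.vstack((cols, rows, scores)).transpose()   # columns: x, y, score

heatmap = np.zeros((64, 64), dtype=np.float32)
heatmap[10, 20] = 0.9
heatmap[40, 5] = 0.7
print(peaks_with_scores(heatmap))   # two rows: (x=20, y=10, ~0.9) and (x=5, y=40, ~0.7)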

lib/dataset/RealDataset.py  (+78)

@@ -0,0 +1,78 @@
+import logging
+import os
+import json
+import numpy as np
+# from scipy.io import loadmat, savemat
+import copy
+
+import h5py
+
+from torch.utils.data import Dataset
+
+# from lib.utils.utils import get_source, get_mesh, evalpotential, generate_heatmaps
+
+logger = logging.getLogger(__name__)
+
+class RealDataset(Dataset):
+    def __init__(self, cfg, root, image_set, is_train, transform):
+        self.cfg = cfg
+        self.is_train = is_train
+
+        self.root = root
+        self.image_set = image_set
+
+        if is_train:
+            raise ValueError('Real dataset has no labels and thus cannot be trained on')
+            # self.is_train = is_train
+
+        self.transform = transform
+
+        self.patch_width = cfg.MODEL.IMAGE_SIZE[0]
+        self.patch_height = cfg.MODEL.IMAGE_SIZE[1]
+
+        self.x_min = cfg.DATASET.X_MIN
+        self.y_min = cfg.DATASET.Y_MIN
+        self.x_max = cfg.DATASET.X_MAX
+        self.y_max = cfg.DATASET.Y_MAX
+
+        self.db = self._get_db()
+        self.db_length = len(self.db)
+
+        logger.info('=> load {} samples'.format(len(self.db)))
+
+    def __getitem__(self, idx):
+        the_db = copy.deepcopy(self.db[idx])
+
+        image = self.transform(the_db['image'])
+
+        return image
+
+    def _get_db(self):
+        gt_db = []
+
+        all_frames = []
+        for i, file_name in enumerate(self.image_set):
+            dataset = {}
+            file_path = os.path.join('data', self.cfg.DATASET.DATASET, file_name)
+            with h5py.File(file_path, 'r') as f:
+                for k, v in f.items():
+                    dataset[k] = np.array(v)
+
+            # assert ('D1' in dataset.keys())
+            assert ('D2' in dataset.keys())
+
+            all_frames = np.expand_dims(dataset['D2'].copy().astype(np.float32), axis=-1).repeat(3, axis=-1)
+            # all_frames_vis = np.transpose(dataset['D1'], (0, 2, 3, 1)).astype(np.uint8)
+
+            all_frames = all_frames[:, self.y_min:self.y_max, self.x_min:self.x_max, :]
+            # all_frames_vis = all_frames_vis[269:270, self.y_min:self.y_max, self.x_min:self.x_max, :]
+
+            for frame in all_frames:
+                gt_db.append({
+                    'image': frame
+                })
+
+        return gt_db
+
+    def __len__(self):
+        return self.db_length
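
Because the frames are read through h5py, the .mat file named in TEST_SET must be stored in MATLAB's v7.3 (HDF5) format and expose a 'D2' array laid out as (num_frames, height, width); _get_db crops each frame to the configured window and repeats it across three channels. A quick way to inspect a candidate file before pointing the config at it (the on-disk path follows the 'data/<DATASET.DATASET>/<file>' pattern used in _get_db):

import h5py

with h5py.File('data/real_dataset/srep31332-s1.mat', 'r') as f:
    print(list(f.keys()))   # 'D2' must be present; the 'D1' check is commented out
    print(f['D2'].shape)    # expected (num_frames, height, width)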

lib/dataset/__init__.py  (+1)

@@ -9,3 +9,4 @@
 from __future__ import print_function

 from lib.dataset.SimulatedDataset import SimulatedDataset as simulation
+from lib.dataset.RealDataset import RealDataset as real_dataset
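
With the alias registered, the real dataset can be built the same way as the simulated one. A minimal usage sketch, assuming torchvision's ToTensor as the transform (the repository's actual test script may compose a different transform):

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import lib.dataset as dataset

# Hypothetical wiring; config is assumed to be the merged experiment config shown above.
test_dataset = dataset.real_dataset(
    config, config.DATASET.ROOT, config.DATASET.TEST_SET,
    is_train=False, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=config.TEST.BATCH_SIZE,
                         shuffle=False, num_workers=config.WORKERS)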

lib/utils/vis.py  (+59)

@@ -0,0 +1,59 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+"""Detection output visualization module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import cv2
+import numpy as np
+import os
+
+# from lib.utils.colormap import colormap
+# import lib.utils.env as envu
+# import lib.utils.keypoints as keypoint_utils
+#
+# # Matplotlib requires certain adjustments in some environments
+# # Must happen before importing matplotlib
+# envu.set_up_matplotlib()
+# import matplotlib.pyplot as plt
+# from matplotlib.patches import Polygon
+#
+# plt.rcParams['pdf.fonttype'] = 42  # For editing in Adobe Illustrator
+#
+#
+# _GRAY = (218, 227, 218)
+# _GREEN = (18, 127, 15)
+# _WHITE = (255, 255, 255)
+
+def vis_preds(img, preds, tp):
+    # Draw the detections.
+    img = img.copy()
+    for idx in range(preds.shape[0]):
+        pt = preds[idx, 0].astype(np.int32) * 2, preds[idx, 1].astype(np.int32) * 2
+
+        if tp[idx]:
+            cv2.circle(
+                img, pt,
+                radius=1, color=(255, 0, 0), thickness=-1, lineType=cv2.LINE_AA)
+        else:
+            cv2.circle(
+                img, pt,
+                radius=1, color=(0, 0, 255), thickness=-1, lineType=cv2.LINE_AA)
+
+    return img
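
vis_preds doubles each coordinate before drawing, which matches the 64×64 heatmaps versus the 128×128 network input in the config above; detections flagged in tp are drawn in one color and the remainder in another. A small usage sketch with dummy data (names are illustrative):

import numpy as np
from lib.utils.vis import vis_preds

# Dummy 128x128 RGB frame and three (x, y, score) detections in heatmap coordinates.
img = np.zeros((128, 128, 3), dtype=np.uint8)
preds = np.array([[10.0, 12.0, 0.9],
                  [30.0, 40.0, 0.8],
                  [55.0, 20.0, 0.6]])
tp = np.array([True, True, False])    # last detection rendered as a miss

overlay = vis_preds(img, preds, tp)   # coordinates are doubled into image pixels
print(overlay.shape)                  # (128, 128, 3)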
