20
20
from lib .core .evaluate import calc_tp_fp_fn
21
21
from lib .core .inference import get_final_preds
22
22
# from utils.transforms import flip_back
23
- # from utils.vis import save_debug_images
23
+ from lib . utils .vis import vis_preds
24
24
# from utils.vis_plain_keypoint import vis_mpii_keypoints
25
25
# from utils.integral import softmax_integral_tensor
26
26
@@ -104,7 +104,7 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
104
104
# prefix)
105
105
106
106
107
- def validate (config , val_loader , val_dataset , model , criterion , output_dir ,
107
+ def validate (config , val_loader , model , criterion , output_dir ,
108
108
tb_log_dir , writer_dict = None ):
109
109
batch_time = AverageMeter ()
110
110
losses = AverageMeter ()
@@ -167,8 +167,16 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
167
167
heatmap_pred [heatmap_pred < 0.0 ] = 0
168
168
heatmap_pred [heatmap_pred > 1.0 ] = 1.0
169
169
170
- writer .add_image ('input_recording' , input_image , global_steps ,
171
- dataformats = 'CHW' )
170
+ input_image = (input_image * 255 ).astype (np .uint8 )
171
+ input_image = np .transpose (input_image , (1 , 2 , 0 ))
172
+ pred = preds [idx ]
173
+ gt = sources [idx ][:valid_source_nums [idx ], :]
174
+
175
+ tp , _ , _ = calc_tp_fp_fn (pred , gt )
176
+ final_preds = vis_preds (input_image , pred , tp )
177
+
178
+ writer .add_image ('final_preds' , final_preds , global_steps ,
179
+ dataformats = 'HWC' )
172
180
writer .add_image ('heatmap_target' , heatmap_target , global_steps ,
173
181
dataformats = 'CHW' )
174
182
writer .add_image ('heatmap_pred' , heatmap_pred , global_steps ,
@@ -182,8 +190,8 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
182
190
total_tp = total_fp = total_fn = 0
183
191
for preds , target in zip (all_preds , all_gts ):
184
192
tp , fp , fn = calc_tp_fp_fn (preds , target )
185
- total_tp += tp
186
- total_fp += fp
193
+ total_tp += np . sum ( tp )
194
+ total_fp += np . sum ( fp )
187
195
total_fn += fn
188
196
189
197
recall = total_tp / (total_tp + total_fn )
@@ -204,6 +212,68 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
204
212
return perf_indicator
205
213
206
214
215
+ def test (config , test_loader , model , output_dir , tb_log_dir ,
216
+ writer_dict = None ):
217
+ batch_time = AverageMeter ()
218
+
219
+ # switch to evaluate mode
220
+ model .eval ()
221
+
222
+ all_preds = []
223
+
224
+ with torch .no_grad ():
225
+ end = time .time ()
226
+ for i , input in enumerate (test_loader ):
227
+ # compute output
228
+ output = model (input )
229
+
230
+ num_images = input .size (0 )
231
+
232
+ preds = get_final_preds (output .detach ().cpu ().numpy ())
233
+ all_preds .extend (preds )
234
+
235
+ # measure elapsed time
236
+ batch_time .update (time .time () - end )
237
+ end = time .time ()
238
+
239
+ if i % config .PRINT_FREQ == 0 :
240
+ msg = 'Test: [{0}/{1}]\t ' \
241
+ 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})' .format (
242
+ i , len (test_loader ), batch_time = batch_time )
243
+ logger .info (msg )
244
+
245
+ if writer_dict :
246
+ writer = writer_dict ['writer' ]
247
+ global_steps = writer_dict ['vis_global_steps' ]
248
+
249
+ idx = np .random .randint (0 , num_images )
250
+
251
+ input_image = input .detach ().cpu ().numpy ()[idx ]
252
+ min_val = input_image .min ()
253
+ max_val = input_image .max ()
254
+ input_image = (input_image - min_val ) / (max_val - min_val )
255
+ heatmap_pred = output .detach ().cpu ().numpy ()[idx ]
256
+ heatmap_pred [heatmap_pred < 0.0 ] = 0
257
+ heatmap_pred [heatmap_pred > 1.0 ] = 1.0
258
+
259
+ input_image = (input_image * 255 ).astype (np .uint8 )
260
+ input_image = np .transpose (input_image , (1 , 2 , 0 ))
261
+ pred = preds [idx ]
262
+ tp = np .ones (pred .shape [0 ], dtype = bool )
263
+ final_preds = vis_preds (input_image , pred , tp )
264
+
265
+ writer .add_image ('final_preds' , final_preds , global_steps ,
266
+ dataformats = 'HWC' )
267
+ writer .add_image ('input_recording' , input_image , global_steps ,
268
+ dataformats = 'HWC' )
269
+ writer .add_image ('heatmap_pred' , heatmap_pred , global_steps ,
270
+ dataformats = 'CHW' )
271
+
272
+ writer_dict ['vis_global_steps' ] = global_steps + 1
273
+
274
+ # prefix = '{}_{}'.format(os.path.join(output_dir, 'val'), i)
275
+
276
+
207
277
# markdown format output
208
278
def _print_name_value (name_value , full_arch_name ):
209
279
names = name_value .keys ()
0 commit comments