2020from lib .core .evaluate import calc_tp_fp_fn
2121from lib .core .inference import get_final_preds
2222# from utils.transforms import flip_back
23- # from utils.vis import save_debug_images
23+ from lib . utils .vis import vis_preds
2424# from utils.vis_plain_keypoint import vis_mpii_keypoints
2525# from utils.integral import softmax_integral_tensor
2626
@@ -104,7 +104,7 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
104104 # prefix)
105105
106106
107- def validate (config , val_loader , val_dataset , model , criterion , output_dir ,
107+ def validate (config , val_loader , model , criterion , output_dir ,
108108 tb_log_dir , writer_dict = None ):
109109 batch_time = AverageMeter ()
110110 losses = AverageMeter ()
@@ -167,8 +167,16 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
167167 heatmap_pred [heatmap_pred < 0.0 ] = 0
168168 heatmap_pred [heatmap_pred > 1.0 ] = 1.0
169169
170- writer .add_image ('input_recording' , input_image , global_steps ,
171- dataformats = 'CHW' )
170+ input_image = (input_image * 255 ).astype (np .uint8 )
171+ input_image = np .transpose (input_image , (1 , 2 , 0 ))
172+ pred = preds [idx ]
173+ gt = sources [idx ][:valid_source_nums [idx ], :]
174+
175+ tp , _ , _ = calc_tp_fp_fn (pred , gt )
176+ final_preds = vis_preds (input_image , pred , tp )
177+
178+ writer .add_image ('final_preds' , final_preds , global_steps ,
179+ dataformats = 'HWC' )
172180 writer .add_image ('heatmap_target' , heatmap_target , global_steps ,
173181 dataformats = 'CHW' )
174182 writer .add_image ('heatmap_pred' , heatmap_pred , global_steps ,
@@ -182,8 +190,8 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
182190 total_tp = total_fp = total_fn = 0
183191 for preds , target in zip (all_preds , all_gts ):
184192 tp , fp , fn = calc_tp_fp_fn (preds , target )
185- total_tp += tp
186- total_fp += fp
193+ total_tp += np . sum ( tp )
194+ total_fp += np . sum ( fp )
187195 total_fn += fn
188196
189197 recall = total_tp / (total_tp + total_fn )
@@ -204,6 +212,68 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
204212 return perf_indicator
205213
206214
215+ def test (config , test_loader , model , output_dir , tb_log_dir ,
216+ writer_dict = None ):
217+ batch_time = AverageMeter ()
218+
219+ # switch to evaluate mode
220+ model .eval ()
221+
222+ all_preds = []
223+
224+ with torch .no_grad ():
225+ end = time .time ()
226+ for i , input in enumerate (test_loader ):
227+ # compute output
228+ output = model (input )
229+
230+ num_images = input .size (0 )
231+
232+ preds = get_final_preds (output .detach ().cpu ().numpy ())
233+ all_preds .extend (preds )
234+
235+ # measure elapsed time
236+ batch_time .update (time .time () - end )
237+ end = time .time ()
238+
239+ if i % config .PRINT_FREQ == 0 :
240+ msg = 'Test: [{0}/{1}]\t ' \
241+ 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})' .format (
242+ i , len (test_loader ), batch_time = batch_time )
243+ logger .info (msg )
244+
245+ if writer_dict :
246+ writer = writer_dict ['writer' ]
247+ global_steps = writer_dict ['vis_global_steps' ]
248+
249+ idx = np .random .randint (0 , num_images )
250+
251+ input_image = input .detach ().cpu ().numpy ()[idx ]
252+ min_val = input_image .min ()
253+ max_val = input_image .max ()
254+ input_image = (input_image - min_val ) / (max_val - min_val )
255+ heatmap_pred = output .detach ().cpu ().numpy ()[idx ]
256+ heatmap_pred [heatmap_pred < 0.0 ] = 0
257+ heatmap_pred [heatmap_pred > 1.0 ] = 1.0
258+
259+ input_image = (input_image * 255 ).astype (np .uint8 )
260+ input_image = np .transpose (input_image , (1 , 2 , 0 ))
261+ pred = preds [idx ]
262+ tp = np .ones (pred .shape [0 ], dtype = bool )
263+ final_preds = vis_preds (input_image , pred , tp )
264+
265+ writer .add_image ('final_preds' , final_preds , global_steps ,
266+ dataformats = 'HWC' )
267+ writer .add_image ('input_recording' , input_image , global_steps ,
268+ dataformats = 'HWC' )
269+ writer .add_image ('heatmap_pred' , heatmap_pred , global_steps ,
270+ dataformats = 'CHW' )
271+
272+ writer_dict ['vis_global_steps' ] = global_steps + 1
273+
274+ # prefix = '{}_{}'.format(os.path.join(output_dir, 'val'), i)
275+
276+
207277# markdown format output
208278def _print_name_value (name_value , full_arch_name ):
209279 names = name_value .keys ()
0 commit comments