|
36 | 36 | train_dataset = train_dataset.batch(args.batch_size) |
37 | 37 | train_dataset = train_dataset.map( |
38 | 38 | lambda x: tf.py_func(get_batch_data, |
39 | | - inp=[x, args.class_num, args.img_size, args.anchors, 'train', args.multi_scale_train, args.use_mix_up], |
| 39 | + inp=[x, args.class_num, args.img_size, args.anchors, 'train', args.multi_scale_train, args.use_mix_up, args.letterbox_resize], |
40 | 40 | Tout=[tf.int64, tf.float32, tf.float32, tf.float32, tf.float32]), |
41 | 41 | num_parallel_calls=args.num_threads |
42 | 42 | ) |
|
46 | 46 | val_dataset = val_dataset.batch(1) |
47 | 47 | val_dataset = val_dataset.map( |
48 | 48 | lambda x: tf.py_func(get_batch_data, |
49 | | - inp=[x, args.class_num, args.img_size, args.anchors, 'val', False, False], |
| 49 | + inp=[x, args.class_num, args.img_size, args.anchors, 'val', False, False, args.letterbox_resize], |
50 | 50 | Tout=[tf.int64, tf.float32, tf.float32, tf.float32, tf.float32]), |
51 | 51 | num_parallel_calls=args.num_threads |
52 | 52 | ) |
|
107 | 107 | # set dependencies for BN ops |
108 | 108 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) |
109 | 109 | with tf.control_dependencies(update_ops): |
110 | | - train_op = optimizer.minimize(loss[0] + l2_loss, var_list=update_vars, global_step=global_step) |
| 110 | + # train_op = optimizer.minimize(loss[0] + l2_loss, var_list=update_vars, global_step=global_step) |
| 111 | + # apple gradient clip to avoid gradient exploding |
| 112 | + gvs = optimizer.compute_gradients(loss[0] + l2_loss, var_list=update_vars) |
| 113 | + clip_grad_var = [gv if gv[0] is None else [ |
| 114 | + tf.clip_by_norm(gv[0], 50.), gv[1]] for gv in gvs] |
| 115 | + train_op = optimizer.apply_gradients(clip_grad_var, global_step=global_step) |
111 | 116 |
|
112 | 117 | if args.save_optimizer: |
113 | 118 | print('Saving optimizer parameters to checkpoint! Remember to restore the global_step in the fine-tuning afterwards.') |
|
166 | 171 | saver_to_save.save(sess, args.save_dir + 'model-epoch_{}_step_{}_loss_{:.4f}_lr_{:.5g}'.format(epoch, int(__global_step), loss_total.average, __lr)) |
167 | 172 |
|
168 | 173 | # switch to validation dataset for evaluation |
169 | | - if epoch % args.val_evaluation_epoch == 0 and epoch > 0: |
| 174 | + if epoch % args.val_evaluation_epoch == 0 and epoch >= args.warm_up_epoch: |
170 | 175 | sess.run(val_init_op) |
171 | 176 |
|
172 | 177 | val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = \ |
|
187 | 192 |
|
188 | 193 | # calc mAP |
189 | 194 | rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() |
190 | | - gt_dict = parse_gt_rec(args.val_file, args.img_size) |
| 195 | + gt_dict = parse_gt_rec(args.val_file, args.img_size, args.letterbox_resize) |
191 | 196 |
|
192 | 197 | info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr) |
193 | 198 |
|
194 | 199 | for ii in range(args.class_num): |
195 | | - npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=args.eval_threshold, use_07_metric=False) |
| 200 | + npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=args.eval_threshold, use_07_metric=args.use_voc_07_metric) |
196 | 201 | info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap) |
197 | 202 | rec_total.update(rec, npos) |
198 | 203 | prec_total.update(prec, nd) |
|
0 commit comments