20
20
from planenet import PlaneNet
21
21
from RecordReader import *
22
22
from RecordReaderRGBD import *
23
+ from RecordReaderScanNet import *
23
24
24
25
#training_flag: toggle dropout and batch normalization mode
25
26
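#it's an integer: even values (training_flag % 2 == 0) select training inputs and training mode, odd values select validation inputs (validation, testing, prediction); values >= 2 select the RGBD inputs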
29
30
def build_graph (img_inp_train , img_inp_val , img_inp_rgbd_train , img_inp_rgbd_val , training_flag , options ):
30
31
with tf .device ('/gpu:%d' % options .gpu_id ):
31
32
img_inp_rgbd = tf .cond (tf .equal (training_flag % 2 , 0 ), lambda : img_inp_rgbd_train , lambda : img_inp_rgbd_val )
32
- img_inp = tf .cond (tf .less (training_flag , 2 ), lambda : tf .cond (tf .equal (training_flag % 2 , 0 ), lambda : img_inp_train , lambda : img_inp_val ), lambda : img_inp_rgbd )
33
+ img_inp = tf .cond (tf .equal (training_flag % 2 , 0 ), lambda : img_inp_train , lambda : img_inp_val )
34
+ img_inp = tf .cond (tf .less (training_flag , 2 ), lambda : img_inp , lambda : img_inp_rgbd )
33
35
34
36
net = PlaneNet ({'img_inp' : img_inp }, is_training = tf .equal (training_flag % 2 , 0 ), options = options )
35
37
@@ -94,6 +96,10 @@ def build_loss_rgbd(global_pred_dict, local_pred_dict, deep_pred_dicts, global_g
94
96
validDepthMask = tf .cast (tf .greater (global_gt_dict ['depth' ], 1e-4 ), tf .float32 )
95
97
depth_loss = tf .reduce_mean (tf .reduce_sum (tf .squared_difference (all_depths , global_gt_dict ['depth' ]) * all_segmentations_softmax , axis = 3 , keep_dims = True ) * validDepthMask ) * 1000
96
98
99
+ if options .predictPixelwise == 1 :
100
+ depth_loss += tf .reduce_mean (tf .squared_difference (global_pred_dict ['non_plane_mask' ], global_gt_dict ['depth' ]) * validDepthMask ) * 1000
101
+ pass
102
+
97
103
#non plane mask loss
98
104
segmentation_loss = tf .reduce_mean (tf .slice (all_segmentations_softmax , [0 , 0 , 0 , options .numOutputPlanes ], [options .batchSize , HEIGHT , WIDTH , 1 ])) * 100
99
105
@@ -148,7 +154,7 @@ def build_loss(global_pred_dict, local_pred_dict, deep_pred_dicts, global_gt_dic
148
154
149
155
plane_gt_shuffled = tf .transpose (tf .matmul (global_gt_dict ['plane' ], forward_map , transpose_a = True ), [0 , 2 , 1 ]) / tf .maximum (num_matches , 1e-4 )
150
156
plane_confidence_gt = tf .cast (num_matches > 0.5 , tf .float32 )
151
- plane_loss += tf .reduce_mean (tf .squared_difference (pred_dict ['plane' ], plane_gt_shuffled ) * plane_confidence_gt ) * 1000
157
+ plane_loss += tf .reduce_mean (tf .squared_difference (pred_dict ['plane' ], plane_gt_shuffled ) * plane_confidence_gt ) * 10000
152
158
153
159
154
160
#all segmentations are the concatenation of the plane segmentations and the non-plane mask
@@ -200,8 +206,13 @@ def build_loss(global_pred_dict, local_pred_dict, deep_pred_dicts, global_gt_dic
200
206
validDepthMask = tf .cast (tf .greater (global_gt_dict ['depth' ], 1e-4 ), tf .float32 )
201
207
depth_loss = tf .reduce_mean (tf .reduce_sum (tf .squared_difference (all_depths , global_gt_dict ['depth' ]) * all_segmentations_softmax , axis = 3 , keep_dims = True ) * validDepthMask ) * 1000
202
208
203
- #normal loss for non-plane region
204
- normal_loss = tf .reduce_mean (tf .squared_difference (global_pred_dict ['non_plane_normal' ], global_gt_dict ['normal' ]) * (1 - plane_mask )) * 1000
209
+ if options .predictPixelwise == 1 :
210
+ depth_loss += tf .reduce_mean (tf .squared_difference (global_pred_dict ['non_plane_mask' ], global_gt_dict ['depth' ]) * validDepthMask ) * 1000
211
+ normal_loss = tf .reduce_mean (tf .squared_difference (global_pred_dict ['non_plane_normal' ], global_gt_dict ['normal' ]) * validDepthMask ) * 1000
212
+ else :
213
+ #normal loss for non-plane region
214
+ normal_loss = tf .reduce_mean (tf .squared_difference (global_pred_dict ['non_plane_normal' ], global_gt_dict ['normal' ]) * (1 - plane_mask )) * 1000
215
+ pass
205
216
206
217
207
218
#local loss
@@ -293,14 +304,6 @@ def build_loss(global_pred_dict, local_pred_dict, deep_pred_dicts, global_gt_dic
293
304
#we predict boundaries directly for post-processing purposes
294
305
boundary_loss += tf .reduce_mean (tf .losses .sigmoid_cross_entropy (logits = global_pred_dict ['boundary' ], multi_class_labels = boundary_gt , weights = tf .maximum (global_gt_dict ['boundary' ] * 3 , 1 ))) * 1000
295
306
296
-
297
- if options .diverseLoss :
298
- plane_diff = tf .reduce_sum (tf .pow (tf .expand_dims (global_pred_dict ['plane' ], 1 ) - tf .expand_dims (global_pred_dict ['plane' ], 2 ), 2 ), axis = 3 )
299
- plane_diff = tf .matrix_set_diag (plane_diff , tf .ones ((options .batchSize , options .numOutputPlanes )))
300
- minPlaneDiff = 0.1
301
- diverse_loss += tf .reduce_mean (tf .clip_by_value (1 - plane_diff / minPlaneDiff , 0 , 1 )) * 10000
302
- pass
303
-
304
307
305
308
#regularization
306
309
l2_losses = tf .add_n ([options .l2Weight * tf .nn .l2_loss (v ) for v in tf .trainable_variables () if 'weights' in v .name ])
@@ -392,7 +395,7 @@ def main(options):
392
395
sess .run (init_op )
393
396
if options .restore == 0 :
394
397
#fine-tune from DeepLab model
395
- var_to_restore = [v for v in var_to_restore if 'res5d' not in v .name and 'segmentation' not in v .name and 'plane' not in v .name and 'deep_supervision' not in v .name and 'local' not in v .name and 'boundary' not in v .name and 'degridding' not in v .name ]
398
+ var_to_restore = [v for v in var_to_restore if 'res5d' not in v .name and 'segmentation' not in v .name and 'plane' not in v .name and 'deep_supervision' not in v .name and 'local' not in v .name and 'boundary' not in v .name and 'degridding' not in v .name and 'res2a_branch2a' not in v . name and 'res2a_branch1' not in v . name ]
396
399
pretrained_model_loader = tf .train .Saver (var_to_restore )
397
400
pretrained_model_loader .restore (sess ,"../pretrained_models/deeplab_resnet.ckpt" )
398
401
elif options .restore == 1 :
@@ -406,8 +409,11 @@ def main(options):
406
409
loader = tf .train .Saver (var_to_restore )
407
410
loader .restore (sess ,"%s/checkpoint.ckpt" % (options .checkpoint_dir ))
408
411
sess .run (batchno .assign (1 ))
409
- elif options .restore == 3 :
412
+ elif options .restore == 3 :
410
413
#restore the same model from standard training
414
+ if options .predictBoundary == 1 :
415
+ var_to_restore = [v for v in var_to_restore if 'boundary' not in v .name ]
416
+ pass
411
417
if options .predictConfidence == 1 :
412
418
var_to_restore = [v for v in var_to_restore if 'confidence' not in v .name ]
413
419
pass
@@ -454,6 +460,7 @@ def main(options):
454
460
batchType = 0
455
461
pass
456
462
463
+
457
464
_ , total_loss , losses , losses_rgbd , summary_str = sess .run ([train_op , loss , loss_dict , loss_dict_rgbd , summary_op ], feed_dict = {training_flag : batchType })
458
465
writers [batchType ].add_summary (summary_str , bno )
459
466
ema [batchType ] = ema [batchType ] * MOVING_AVERAGE_DECAY + total_loss
@@ -467,6 +474,14 @@ def main(options):
467
474
pass
468
475
469
476
print bno ,'train' , ema [0 ] / ema_acc [0 ], 'val' , ema [1 ] / ema_acc [1 ], 'train rgbd' , ema [2 ] / ema_acc [2 ], 'val rgbd' , ema [3 ] / ema_acc [3 ], 'loss' , total_loss , 'time' , time .time ()- t0
477
+
478
+ if np .random .random () < 0.01 :
479
+ if batchType < 2 :
480
+ print (losses )
481
+ else :
482
+ print (losses_rgbd )
483
+ pass
484
+ pass
470
485
continue
471
486
472
487
except tf .errors .OutOfRangeError :
@@ -1056,7 +1071,10 @@ def parse_args():
1056
1071
default = 0 , type = int )
1057
1072
parser .add_argument ('--predictConfidence' , dest = 'predictConfidence' ,
1058
1073
help = 'whether predict plane confidence or not: [0, 1]' ,
1059
- default = 0 , type = int )
1074
+ default = 0 , type = int )
1075
+ parser .add_argument ('--predictPixelwise' , dest = 'predictPixelwise' ,
1076
+ help = 'whether predict pixelwise depth or not: [0, 1]' ,
1077
+ default = 0 , type = int )
1060
1078
parser .add_argument ('--fineTuningCheckpoint' , dest = 'fineTuningCheckpoint' ,
1061
1079
help = 'specify the model for fine-tuning' ,
1062
1080
default = '../PlaneSetGeneration/dump_planenet_diverse/train_planenet_diverse.ckpt' , type = str )
@@ -1105,6 +1123,9 @@ def parse_args():
1105
1123
if args .predictConfidence == 1 :
1106
1124
args .keyname += '_pc'
1107
1125
pass
1126
+ if args .predictPixelwise == 1 :
1127
+ args .keyname += '_pp'
1128
+ pass
1108
1129
if args .sameMatching == 0 :
1109
1130
args .keyname += '_sm0'
1110
1131
pass
0 commit comments