|
| 1 | +import tensorflow as tf |
| 2 | +import numpy as np |
| 3 | +np.set_printoptions(precision=2, linewidth=200) |
| 4 | +import cv2 |
| 5 | +import os |
| 6 | +import time |
| 7 | +import sys |
| 8 | +import argparse |
| 9 | +import glob |
| 10 | + |
| 11 | +#sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| 12 | +#from planenet_utils import calcPlaneDepths, drawSegmentationImage, drawDepthImage |
| 13 | +from PlaneNet.utils import calcPlaneDepths, drawSegmentationImage, drawDepthImage |
| 14 | + |
| 15 | +from train_planenet import build_graph, parse_args |
| 16 | + |
| 17 | +WIDTH = 256 |
| 18 | +HEIGHT = 192 |
| 19 | + |
| 20 | +ALL_TITLES = ['PlaneNet'] |
| 21 | +ALL_METHODS = [('sample_np10_hybrid3_bl0_dl0_ds0_crfrnn5_sm0', '', 0, 2)] |
| 22 | + |
| 23 | +class PlaneNetDetector(): |
| 24 | + def __init__(self, batchSize=1): |
| 25 | + tf.reset_default_graph() |
| 26 | + |
| 27 | + self.img_inp = tf.placeholder(tf.float32, shape=[batchSize, HEIGHT, WIDTH, 3], name='image') |
| 28 | + training_flag = tf.constant(False, tf.bool) |
| 29 | + |
| 30 | + self.options = parse_args() |
| 31 | + self.global_pred_dict, _, _ = build_graph(self.img_inp, self.img_inp, training_flag, self.options) |
| 32 | + |
| 33 | + var_to_restore = tf.global_variables() |
| 34 | + |
| 35 | + config = tf.ConfigProto() |
| 36 | + config.gpu_options.allow_growth = True |
| 37 | + config.allow_soft_placement = True |
| 38 | + init_op = tf.group(tf.global_variables_initializer(), |
| 39 | + tf.local_variables_initializer()) |
| 40 | + |
| 41 | + |
| 42 | + self.sess = tf.Session(config=config) |
| 43 | + self.sess.run(init_op) |
| 44 | + loader = tf.train.Saver(var_to_restore) |
| 45 | + path = os.path.dirname(os.path.realpath(__file__)) |
| 46 | + checkpoint_dir = path + '/checkpoint/sample_np10_hybrid3_bl0_dl0_ds0_crfrnn5_sm0' |
| 47 | + loader.restore(self.sess, "%s/checkpoint.ckpt"%(checkpoint_dir)) |
| 48 | + return |
| 49 | + |
| 50 | + def detect(self, image, estimateFocalLength=False): |
| 51 | + |
| 52 | + pred_dict = {} |
| 53 | + if True: |
| 54 | + t0 = time.time() |
| 55 | + |
| 56 | + #image_inp = np.array([cv2.resize(image, (WIDTH, HEIGHT)) for image in images]) |
| 57 | + image_inp = np.expand_dims(cv2.resize(image, (WIDTH, HEIGHT)), 0) |
| 58 | + image_inp = image_inp.astype(np.float32) / 255 - 0.5 |
| 59 | + global_pred = self.sess.run(self.global_pred_dict, feed_dict={self.img_inp: image_inp}) |
| 60 | + |
| 61 | + pred_p = global_pred['plane'][0] |
| 62 | + pred_s = global_pred['segmentation'][0] |
| 63 | + pred_np_m = global_pred['non_plane_mask'][0] |
| 64 | + pred_np_d = global_pred['non_plane_depth'][0] |
| 65 | + |
| 66 | + all_segmentations = np.concatenate([pred_s, pred_np_m], axis=2) |
| 67 | + |
| 68 | + info = np.zeros(20) |
| 69 | + if estimateFocalLength: |
| 70 | + focalLength = estimateFocalLength(img_ori) |
| 71 | + info[0] = focalLength |
| 72 | + info[5] = focalLength |
| 73 | + info[2] = img_ori.shape[1] / 2 |
| 74 | + info[6] = img_ori.shape[0] / 2 |
| 75 | + info[16] = img_ori.shape[1] |
| 76 | + info[17] = img_ori.shape[0] |
| 77 | + info[10] = 1 |
| 78 | + info[15] = 1 |
| 79 | + info[18] = 1000 |
| 80 | + info[19] = 5 |
| 81 | + else: |
| 82 | + info[0] = 571.87 |
| 83 | + info[2] = 320 |
| 84 | + info[5] = 571.87 |
| 85 | + info[6] = 240 |
| 86 | + info[16] = 640 |
| 87 | + info[17] = 480 |
| 88 | + info[10] = 1 |
| 89 | + info[15] = 1 |
| 90 | + info[18] = 1000 |
| 91 | + info[19] = 5 |
| 92 | + pass |
| 93 | + |
| 94 | + #width_high_res = images[0].shape[1] |
| 95 | + #height_high_res = images[0].shape[0] |
| 96 | + width_high_res = 640 |
| 97 | + height_high_res = 480 |
| 98 | + |
| 99 | + plane_depths = calcPlaneDepths(pred_p, width_high_res, height_high_res, info) |
| 100 | + |
| 101 | + pred_np_d = np.expand_dims(cv2.resize(pred_np_d.squeeze(), (width_high_res, height_high_res)), -1) |
| 102 | + all_depths = np.concatenate([plane_depths, pred_np_d], axis=2) |
| 103 | + |
| 104 | + all_segmentations = np.stack([cv2.resize(all_segmentations[:, :, planeIndex], (width_high_res, height_high_res)) for planeIndex in xrange(all_segmentations.shape[-1])], axis=2) |
| 105 | + |
| 106 | + segmentation = np.argmax(all_segmentations, 2) |
| 107 | + pred_d = all_depths.reshape(-1, self.options.numOutputPlanes + 1)[np.arange(height_high_res * width_high_res), segmentation.reshape(-1)].reshape(height_high_res, width_high_res) |
| 108 | + |
| 109 | + #print(pred_p) |
| 110 | + # for segmentIndex in range(segmentation.max() + 1): |
| 111 | + # cv2.imwrite('test/mask_' + str(segmentIndex) + '.png', (segmentation == segmentIndex).astype(np.uint8) * 255) |
| 112 | + # print(all_depths[:, :, segmentIndex].min(), all_depths[:, :, segmentIndex].max()) |
| 113 | + # cv2.imwrite('test/depth_' + str(segmentIndex) + '.png', drawDepthImage(all_depths[:, :, segmentIndex])) |
| 114 | + # continue |
| 115 | + pred_dict['plane'] = pred_p |
| 116 | + pred_dict['segmentation'] = segmentation |
| 117 | + pred_dict['depth'] = pred_d |
| 118 | + pred_dict['info'] = info |
| 119 | + else: |
| 120 | + print('prediction failed') |
| 121 | + pass |
| 122 | + return pred_dict |
0 commit comments