import tensorflow as tf
import numpy as np
import threading
import PIL.Image as Image
from functools import partial
from multiprocessing import Pool
import cv2

import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from modules import *


# Input resolution of each example and the maximum number of planes stored per record.
HEIGHT = 192
WIDTH = 256
NUM_PLANES = 20
NUM_THREADS = 4


class RecordReaderAll():
    def __init__(self):
        return

    def getBatch(self, filename_queue, numOutputPlanes=20, batchSize=16, min_after_dequeue=1000, random=True, getLocal=False, getSegmentation=False, test=True):
        # Read one serialized example from the TFRecord filename queue.
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        features = tf.parse_single_example(
            serialized_example,
            # Defaults are not specified since all keys are required.
            features={
                #'height': tf.FixedLenFeature([], tf.int64),
                #'width': tf.FixedLenFeature([], tf.int64),
                'image_raw': tf.FixedLenFeature([], tf.string),
                'image_path': tf.FixedLenFeature([], tf.string),
                'num_planes': tf.FixedLenFeature([], tf.int64),
                'plane': tf.FixedLenFeature([NUM_PLANES * 3], tf.float32),
                #'plane_relation': tf.FixedLenFeature([NUM_PLANES * NUM_PLANES], tf.float32),
                'segmentation_raw': tf.FixedLenFeature([], tf.string),
                'depth': tf.FixedLenFeature([HEIGHT * WIDTH], tf.float32),
                'normal': tf.FixedLenFeature([HEIGHT * WIDTH * 3], tf.float32),
                'semantics_raw': tf.FixedLenFeature([], tf.string),
                'boundary_raw': tf.FixedLenFeature([], tf.string),
                'info': tf.FixedLenFeature([4 * 4 + 4], tf.float32),
            })

        # Decode the raw image bytes to uint8, scale pixel values to [-0.5, 0.5],
        # and reshape to a [HEIGHT, WIDTH, 3] float tensor.
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
        image = tf.reshape(image, [HEIGHT, WIDTH, 3])

        # Per-pixel depth and unit-length surface normals.
        depth = features['depth']
        depth = tf.reshape(depth, [HEIGHT, WIDTH, 1])

        normal = features['normal']
        normal = tf.reshape(normal, [HEIGHT, WIDTH, 3])
        normal = tf.nn.l2_normalize(normal, dim=2)
        #normal = tf.stack([normal[:, :, 1], normal[:, :, 0], normal[:, :, 2]], axis=2)

        # Per-pixel semantic labels.
        semantics = tf.decode_raw(features['semantics_raw'], tf.uint8)
        semantics = tf.cast(tf.reshape(semantics, [HEIGHT, WIDTH]), tf.int32)

        # Clamp the plane count to at most numOutputPlanes; keep that count
        # (before the floor of 1) as the per-example ground-truth value.
        numPlanes = tf.minimum(tf.cast(features['num_planes'], tf.int32), numOutputPlanes)
        numPlanesOri = numPlanes
        numPlanes = tf.maximum(numPlanes, 1)

        # Keep the valid plane parameters and zero-pad them up to numOutputPlanes rows.
        planes = features['plane']
        planes = tf.reshape(planes, [NUM_PLANES, 3])
        planes = tf.slice(planes, [0, 0], [numPlanes, 3])

        #shuffle_inds = tf.one_hot(tf.random_shuffle(tf.range(numPlanes)), depth=numPlanes)
        shuffle_inds = tf.one_hot(tf.range(numPlanes), numPlanes)

        planes = tf.transpose(tf.matmul(tf.transpose(planes), shuffle_inds))
        planes = tf.reshape(planes, [numPlanes, 3])
        planes = tf.concat([planes, tf.zeros([numOutputPlanes - numPlanes, 3])], axis=0)
        planes = tf.reshape(planes, [numOutputPlanes, 3])

        boundary = tf.decode_raw(features['boundary_raw'], tf.uint8)
        boundary = tf.cast(tf.reshape(boundary, (HEIGHT, WIDTH, 2)), tf.float32)

        #boundary = tf.decode_raw(features['boundary_raw'], tf.float64)
        #boundary = tf.cast(tf.reshape(boundary, (HEIGHT, WIDTH, 3)), tf.float32)
        #boundary = tf.slice(boundary, [0, 0, 0], [HEIGHT, WIDTH, 2])

        segmentation = tf.decode_raw(features['segmentation_raw'], tf.uint8)
        segmentation = tf.reshape(segmentation, [HEIGHT, WIDTH, 1])

        # Convert the index segmentation map into one binary mask per plane,
        # padded with zero masks up to numOutputPlanes channels.
        coef = tf.range(numPlanes)
        coef = tf.reshape(tf.matmul(tf.reshape(coef, [-1, numPlanes]), tf.cast(shuffle_inds, tf.int32)), [1, 1, numPlanes])

        plane_masks = tf.cast(tf.equal(segmentation, tf.cast(coef, tf.uint8)), tf.float32)
        plane_masks = tf.concat([plane_masks, tf.zeros([HEIGHT, WIDTH, numOutputPlanes - numPlanes])], axis=2)
        plane_masks = tf.reshape(plane_masks, [HEIGHT, WIDTH, numOutputPlanes])

        # Pixels not covered by any plane mask form the non-planar region.
        #non_plane_mask = tf.cast(tf.equal(segmentation, tf.cast(numOutputPlanes, tf.uint8)), tf.float32)
        non_plane_mask = 1 - tf.reduce_max(plane_masks, axis=2, keep_dims=True)

        # Batch the parsed tensors, shuffling across the queue when random=True.
        if random:
            image_inp, plane_inp, depth_gt, normal_gt, semantics_gt, plane_masks_gt, boundary_gt, num_planes_gt, non_plane_mask_gt, image_path, info = tf.train.shuffle_batch(
                [image, planes, depth, normal, semantics, plane_masks, boundary, numPlanesOri, non_plane_mask, features['image_path'], features['info']],
                batch_size=batchSize, capacity=min_after_dequeue + (NUM_THREADS + 2) * batchSize,
                num_threads=NUM_THREADS, min_after_dequeue=min_after_dequeue)
        else:
            image_inp, plane_inp, depth_gt, normal_gt, semantics_gt, plane_masks_gt, boundary_gt, num_planes_gt, non_plane_mask_gt, image_path, info = tf.train.batch(
                [image, planes, depth, normal, semantics, plane_masks, boundary, numPlanesOri, non_plane_mask, features['image_path'], features['info']],
                batch_size=batchSize, capacity=(NUM_THREADS + 2) * batchSize, num_threads=1)
            pass
        global_gt_dict = {'plane': plane_inp, 'depth': depth_gt, 'normal': normal_gt, 'semantics': semantics_gt,
                          'segmentation': plane_masks_gt, 'boundary': boundary_gt, 'num_planes': num_planes_gt,
                          'non_plane_mask': non_plane_mask_gt, 'image_path': image_path, 'info': info}
        return image_inp, global_gt_dict, {}
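

# Minimal usage sketch of the reader with the TF 1.x queue-based input pipeline.
# This block is an assumption-laden example, not part of the reader itself:
# 'planes_example.tfrecords' is a hypothetical record path, and batchSize=4 is
# an arbitrary choice for illustration.
if __name__ == '__main__':
    filename_queue = tf.train.string_input_producer(['planes_example.tfrecords'], shuffle=False)
    reader = RecordReaderAll()
    image_inp, global_gt_dict, _ = reader.getBatch(filename_queue, numOutputPlanes=NUM_PLANES, batchSize=4, random=False)

    with tf.Session() as sess:
        sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
        # Queue runners must be started before the batch ops can produce data.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # One batch: images of shape (4, HEIGHT, WIDTH, 3) in [-0.5, 0.5],
            # and ground-truth tensors such as global_gt_dict['plane'] of shape (4, NUM_PLANES, 3).
            images, gt = sess.run([image_inp, global_gt_dict])
            print(images.shape, gt['plane'].shape)
        finally:
            coord.request_stop()
            coord.join(threads)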