Skip to content

Commit 3aa718a

Browse files
authored
Add files via upload
added test results and modified file to convert data in record format
1 parent f7b7b3a commit 3aa718a

File tree

7 files changed

+251
-0
lines changed

7 files changed

+251
-0
lines changed

extra/create_mask_rcnn_tf_record.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
r"""Convert the Oxford pet dataset to TFRecord for object_detection.
17+
18+
See: O. M. Parkhi, A. Vedaldi, A. Zisserman, C. V. Jawahar
19+
Cats and Dogs
20+
IEEE Conference on Computer Vision and Pattern Recognition, 2012
21+
http://www.robots.ox.ac.uk/~vgg/data/pets/
22+
23+
Example usage:
24+
    python extra/create_mask_rcnn_tf_record.py \
25+
--data_dir=/home/user/pet \
26+
        --output_dir=/home/user/pet/output \
        --label_map_path=/home/user/pet/label_map.pbtxt
27+
"""
28+
29+
import hashlib
30+
import io
31+
import logging
32+
import os
33+
import random
34+
import re
35+
36+
import contextlib2
37+
from lxml import etree
38+
import numpy as np
39+
import PIL.Image
40+
import tensorflow as tf
41+
42+
from object_detection.dataset_tools import tf_record_creation_util
43+
from object_detection.utils import dataset_util
44+
from object_detection.utils import label_map_util
45+
46+
# Command-line flags. main() joins `image_dir` and `annotations_dir` onto
# `data_dir` to locate the inputs.
flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Path to root directory to dataset.')
# main() passes this value to create_tf_record() as `output_filename`, so it is
# the base path of the sharded output files rather than a directory.
flags.DEFINE_string('output_dir', '',
                    'Base path (filename prefix) for the output TFRecord '
                    'shards.')
flags.DEFINE_string('image_dir', 'JPEGImages',
                    'Name of the directory containing images')
# create_tf_record() reads 'xmls/<name>.xml' and 'masks/<name>.png' below this
# directory.
flags.DEFINE_string('annotations_dir', 'Annotations',
                    'Name of the directory containing Annotations')
flags.DEFINE_string('label_map_path', '', 'Path to label map proto')
flags.DEFINE_integer('num_shards', 1, 'Number of TFRecord shards')
FLAGS = flags.FLAGS

# mask_pixel: maps each class name to the pixel value (after the mask PNG has
# been converted to 8-bit grayscale) that marks pixels belonging to that class.
# Change as per your classes and labeling.
mask_pixel = {'speaker': 25, 'cup': 32}
58+
59+
def dict_to_tf_example(data,
                       mask_path,
                       label_map_dict,
                       image_subdirectory,
                       ignore_difficult_instances=False):
    """Convert an XML-derived dict plus a class-mask PNG to a tf.Example proto.

    Bounding boxes are NOT taken from the XML: for every object the box is the
    extent of the mask pixels whose grayscale value equals
    ``mask_pixel[obj['name']]``, normalized by the image width/height recorded
    in the XML. Because the mask is selected by class name only, all objects
    of the same class in one image receive the same mask and bounding box.

    Args:
      data: dict holding PASCAL XML fields for a single image (obtained by
        running dataset_util.recursive_parse_xml_to_dict).
      mask_path: String path to PNG encoded mask.
      label_map_dict: A map from string label names to integer ids.
      image_subdirectory: String path of the directory that contains the image
        named by data['filename'].
      ignore_difficult_instances: Whether to skip difficult instances in the
        dataset (default: False).

    Returns:
      example: The converted tf.Example.

    Raises:
      ValueError: if the image is not a JPEG, the mask is not a PNG, or an
        object's class has no pixels in the mask.
      KeyError: if an object's class name is missing from `mask_pixel` or
        `label_map_dict`.
    """
    img_path = os.path.join(image_subdirectory, data['filename'])
    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    image = PIL.Image.open(io.BytesIO(encoded_jpg))
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
    key = hashlib.sha256(encoded_jpg).hexdigest()

    with tf.gfile.GFile(mask_path, 'rb') as fid:
        encoded_mask_png = fid.read()
    mask_image = PIL.Image.open(io.BytesIO(encoded_mask_png))
    # Check the container format before spending time decoding the pixels.
    if mask_image.format != 'PNG':
        raise ValueError('Mask format not PNG')
    # 8-bit grayscale so pixel values can be compared with `mask_pixel`.
    mask_np = np.asarray(mask_image.convert('L'))

    width = int(data['size']['width'])
    height = int(data['size']['height'])

    xmins = []
    ymins = []
    xmaxs = []
    ymaxs = []
    classes = []
    classes_text = []
    truncated = []
    poses = []
    difficult_obj = []
    masks = []
    for obj in data.get('object', []):
        difficult = bool(int(obj['difficult']))
        if ignore_difficult_instances and difficult:
            continue

        class_name = obj['name']
        # Boolean mask of this class, computed once and reused for both the
        # bounding box and the stored instance mask.
        class_mask = mask_np == mask_pixel[class_name]
        x_indices = np.where(np.any(class_mask, axis=0))[0]
        y_indices = np.where(np.any(class_mask, axis=1))[0]
        if x_indices.size == 0:
            # Same exception type NumPy would raise from min() on an empty
            # array, but with a message that identifies the problem.
            raise ValueError(
                'No mask pixels with value %r for class %r in %s' %
                (mask_pixel[class_name], class_name, mask_path))

        xmin = float(x_indices.min())
        xmax = float(x_indices.max())
        ymin = float(y_indices.min())
        ymax = float(y_indices.max())
        print(data['filename'], 'bounding box for', class_name,
              xmin, xmax, ymin, ymax)

        difficult_obj.append(int(difficult))
        xmins.append(xmin / width)
        ymins.append(ymin / height)
        xmaxs.append(xmax / width)
        ymaxs.append(ymax / height)

        classes_text.append(class_name.encode('utf8'))
        classes.append(label_map_dict[class_name])
        truncated.append(int(obj['truncated']))
        poses.append(obj['pose'].encode('utf8'))

        # Binary {0, 1} mask for this object.
        masks.append(class_mask.astype(np.uint8))

    feature_dict = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(
            classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature(
            difficult_obj),
        'image/object/truncated': dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
    }

    # Each per-object binary mask is stored as its own PNG-encoded byte string.
    encoded_mask_png_list = []
    for instance_mask in masks:
        output = io.BytesIO()
        PIL.Image.fromarray(instance_mask).save(output, format='PNG')
        encoded_mask_png_list.append(output.getvalue())
    feature_dict['image/object/mask'] = (
        dataset_util.bytes_list_feature(encoded_mask_png_list))

    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example
178+
179+
180+
def create_tf_record(output_filename,
                     num_shards,
                     label_map_dict,
                     annotations_dir,
                     image_dir,
                     examples):
    """Creates sharded TFRecord files from examples.

    For each example name, the annotation is read from
    `<annotations_dir>/xmls/<name>.xml` and the mask from
    `<annotations_dir>/masks/<name>.png`. Examples whose XML or mask file is
    missing, or that dict_to_tf_example rejects with ValueError, are logged
    and skipped.

    Args:
      output_filename: Path to where output file is saved.
      num_shards: Number of shards for output file.
      label_map_dict: The label map dictionary.
      annotations_dir: Directory where annotation files are stored.
      image_dir: Directory where image files are stored.
      examples: Examples to parse and save to tf record.
    """
    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
            tf_record_close_stack, output_filename, num_shards)
        for idx, example in enumerate(examples):
            if idx % 100 == 0:
                logging.info('On image %d of %d', idx, len(examples))
            xml_path = os.path.join(annotations_dir, 'xmls', example + '.xml')
            mask_path = os.path.join(annotations_dir, 'masks', example + '.png')

            if not os.path.exists(xml_path):
                logging.warning('Could not find %s, ignoring example.', xml_path)
                continue
            # Treat a missing mask like a missing XML instead of letting the
            # read inside dict_to_tf_example abort the whole conversion.
            if not os.path.exists(mask_path):
                logging.warning('Could not find %s, ignoring example.',
                                mask_path)
                continue

            # Read as bytes: lxml rejects unicode strings that carry an XML
            # encoding declaration, but parses the raw bytes fine.
            with tf.gfile.GFile(xml_path, 'rb') as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

            try:
                tf_example = dict_to_tf_example(
                    data,
                    mask_path,
                    label_map_dict,
                    image_dir)
            except ValueError:
                logging.warning('Invalid example: %s, ignoring.', xml_path)
            else:
                if tf_example:
                    # Spread examples round-robin across the shards.
                    shard_idx = idx % num_shards
                    output_tfrecords[shard_idx].write(
                        tf_example.SerializeToString())
                    print("done")
225+
226+
def main(_):
    """Convert every '.jpg' image under the image directory to TFRecord shards.

    Reads the locations from FLAGS, derives the example names from the JPEG
    file names (file name without the '.jpg' suffix) and hands them to
    create_tf_record.

    Args:
      _: Unused positional argument supplied by tf.app.run().
    """
    data_dir = FLAGS.data_dir
    train_output_path = FLAGS.output_dir
    image_dir = os.path.join(data_dir, FLAGS.image_dir)
    annotations_dir = os.path.join(data_dir, FLAGS.annotations_dir)
    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

    logging.info('Reading from dataset.')
    # Build a new list rather than deleting from the list being iterated:
    # deleting in place skips the entry after every removed one, letting
    # non-JPEG files slip through and get their names truncated.
    jpg_suffix = '.jpg'
    examples_list = [
        filename[:-len(jpg_suffix)]
        for filename in os.listdir(image_dir)
        if filename.endswith(jpg_suffix)
    ]

    create_tf_record(
        train_output_path,
        FLAGS.num_shards,
        label_map_dict,
        annotations_dir,
        image_dir,
        examples_list)
248+
249+
250+
# Script entry point: let tf.app.run() drive main() when executed directly.
if __name__ == '__main__':
    tf.app.run(main)

extra/result_images/multi_mask1.png

170 KB
Loading

extra/result_images/multi_mask2.png

204 KB
Loading

extra/result_images/multi_mask3.png

184 KB
Loading

extra/result_images/multi_mask4.png

181 KB
Loading

extra/result_images/multi_mask5.png

175 KB
Loading

extra/result_images/multi_mask6.png

136 KB
Loading

0 commit comments

Comments
 (0)