1
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ r"""Convert raw COCO dataset to TFRecord for object_detection.
17
+
18
+ Please note that this tool creates sharded output files.
19
+
20
+ Example usage:
21
+ python create_coco_tf_record.py --logtostderr \
22
+ --train_image_dir="${TRAIN_IMAGE_DIR}" \
23
+ --test_image_dir="${TEST_IMAGE_DIR}" \
24
+ --train_annotations_file="${TRAIN_ANNOTATIONS_FILE}" \
25
+ --test_annotations_file="${TEST_ANNOTATIONS_FILE}" \
26
+ --output_dir="${OUTPUT_DIR}"
27
+ """
28
+ from __future__ import absolute_import
29
+ from __future__ import division
30
+ from __future__ import print_function
31
+
32
+ import hashlib
33
+ import io
34
+ import json
35
+ import os
36
+ import contextlib2
37
+ import numpy as np
38
+ import PIL .Image
39
+
40
+ from pycocotools import mask
41
+ import tensorflow as tf
42
+
43
+ from object_detection .dataset_tools import tf_record_creation_util
44
+ from object_detection .utils import dataset_util
45
+ from object_detection .utils import label_map_util
46
+
47
+
48
+ flags = tf .app .flags
49
+ tf .flags .DEFINE_boolean ('include_masks' , False ,
50
+ 'Whether to include instance segmentations masks '
51
+ '(PNG encoded) in the result. default: False.' )
52
+ tf .flags .DEFINE_string ('train_image_dir' , '' ,
53
+ 'Training image directory.' )
54
+ tf .flags .DEFINE_string ('test_image_dir' , '' ,
55
+ 'Test image directory.' )
56
+ tf .flags .DEFINE_string ('train_annotations_file' , '' ,
57
+ 'Training annotations JSON file.' )
58
+ tf .flags .DEFINE_string ('test_annotations_file' , '' ,
59
+ 'Test-dev annotations JSON file.' )
60
+ tf .flags .DEFINE_string ('output_dir' , '/tmp/' , 'Output data directory.' )
61
+
62
+ FLAGS = flags .FLAGS
63
+
64
+ tf .logging .set_verbosity (tf .logging .INFO )
65
+
66
+
67
+ def create_tf_example (image ,
68
+ annotations_list ,
69
+ image_dir ,
70
+ category_index ,
71
+ include_masks = False ):
72
+ """Converts image and annotations to a tf.Example proto.
73
+
74
+ Args:
75
+ image: dict with keys:
76
+ [u'license', u'file_name', u'coco_url', u'height', u'width',
77
+ u'date_captured', u'flickr_url', u'id']
78
+ annotations_list:
79
+ list of dicts with keys:
80
+ [u'segmentation', u'area', u'iscrowd', u'image_id',
81
+ u'bbox', u'category_id', u'id']
82
+ Notice that bounding box coordinates in the official COCO dataset are
83
+ given as [x, y, width, height] tuples using absolute coordinates where
84
+ x, y represent the top-left (0-indexed) corner. This function converts
85
+ to the format expected by the Tensorflow Object Detection API (which is
86
+ which is [ymin, xmin, ymax, xmax] with coordinates normalized relative
87
+ to image size).
88
+ image_dir: directory containing the image files.
89
+ category_index: a dict containing COCO category information keyed
90
+ by the 'id' field of each category. See the
91
+ label_map_util.create_category_index function.
92
+ include_masks: Whether to include instance segmentations masks
93
+ (PNG encoded) in the result. default: False.
94
+ Returns:
95
+ example: The converted tf.Example
96
+ num_annotations_skipped: Number of (invalid) annotations that were ignored.
97
+
98
+ Raises:
99
+ ValueError: if the image pointed to by data['filename'] is not a valid JPEG
100
+ """
101
+ image_height = image ['height' ]
102
+ image_width = image ['width' ]
103
+ filename = image ['file_name' ]
104
+ image_id = image ['id' ]
105
+
106
+ full_path = os .path .join (image_dir , filename )
107
+ with tf .gfile .GFile (full_path , 'rb' ) as fid :
108
+ encoded_jpg = fid .read ()
109
+ encoded_jpg_io = io .BytesIO (encoded_jpg )
110
+ image = PIL .Image .open (encoded_jpg_io )
111
+ key = hashlib .sha256 (encoded_jpg ).hexdigest ()
112
+
113
+ xmin = []
114
+ xmax = []
115
+ ymin = []
116
+ ymax = []
117
+ is_crowd = []
118
+ category_names = []
119
+ category_ids = []
120
+ area = []
121
+ encoded_mask_png = []
122
+ num_annotations_skipped = 0
123
+ for object_annotations in annotations_list :
124
+ (x , y , width , height ) = tuple (object_annotations ['bbox' ])
125
+ if width <= 0 or height <= 0 :
126
+ num_annotations_skipped += 1
127
+ continue
128
+ if x + width > image_width or y + height > image_height :
129
+ num_annotations_skipped += 1
130
+ continue
131
+ xmin .append (float (x ) / image_width )
132
+ xmax .append (float (x + width ) / image_width )
133
+ ymin .append (float (y ) / image_height )
134
+ ymax .append (float (y + height ) / image_height )
135
+ is_crowd .append (object_annotations ['iscrowd' ])
136
+ category_id = int (object_annotations ['category_id' ])
137
+ category_ids .append (category_id )
138
+ category_names .append (category_index [category_id ]['name' ].encode ('utf8' ))
139
+ area .append (object_annotations ['area' ])
140
+
141
+ if include_masks :
142
+ run_len_encoding = mask .frPyObjects (object_annotations ['segmentation' ],
143
+ image_height , image_width )
144
+ binary_mask = mask .decode (run_len_encoding )
145
+ if not object_annotations ['iscrowd' ]:
146
+ binary_mask = np .amax (binary_mask , axis = 2 )
147
+ pil_image = PIL .Image .fromarray (binary_mask )
148
+ output_io = io .BytesIO ()
149
+ pil_image .save (output_io , format = 'PNG' )
150
+ encoded_mask_png .append (output_io .getvalue ())
151
+ feature_dict = {
152
+ 'image/height' :
153
+ dataset_util .int64_feature (image_height ),
154
+ 'image/width' :
155
+ dataset_util .int64_feature (image_width ),
156
+ 'image/filename' :
157
+ dataset_util .bytes_feature (filename .encode ('utf8' )),
158
+ 'image/source_id' :
159
+ dataset_util .bytes_feature (str (image_id ).encode ('utf8' )),
160
+ 'image/key/sha256' :
161
+ dataset_util .bytes_feature (key .encode ('utf8' )),
162
+ 'image/encoded' :
163
+ dataset_util .bytes_feature (encoded_jpg ),
164
+ 'image/format' :
165
+ dataset_util .bytes_feature ('jpeg' .encode ('utf8' )),
166
+ 'image/object/bbox/xmin' :
167
+ dataset_util .float_list_feature (xmin ),
168
+ 'image/object/bbox/xmax' :
169
+ dataset_util .float_list_feature (xmax ),
170
+ 'image/object/bbox/ymin' :
171
+ dataset_util .float_list_feature (ymin ),
172
+ 'image/object/bbox/ymax' :
173
+ dataset_util .float_list_feature (ymax ),
174
+ 'image/object/class/text' :
175
+ dataset_util .bytes_list_feature (category_names ),
176
+ 'image/object/is_crowd' :
177
+ dataset_util .int64_list_feature (is_crowd ),
178
+ 'image/object/area' :
179
+ dataset_util .float_list_feature (area ),
180
+ }
181
+ if include_masks :
182
+ feature_dict ['image/object/mask' ] = (
183
+ dataset_util .bytes_list_feature (encoded_mask_png ))
184
+ example = tf .train .Example (features = tf .train .Features (feature = feature_dict ))
185
+ return key , example , num_annotations_skipped
186
+
187
+
188
+ def _create_tf_record_from_coco_annotations (
189
+ annotations_file , image_dir , output_path , include_masks ):
190
+ """Loads COCO annotation json files and converts to tf.Record format.
191
+
192
+ Args:
193
+ annotations_file: JSON file containing bounding box annotations.
194
+ image_dir: Directory containing the image files.
195
+ output_path: Path to output tf.Record file.
196
+ include_masks: Whether to include instance segmentations masks
197
+ (PNG encoded) in the result. default: False.
198
+ """
199
+ with tf .gfile .GFile (annotations_file , 'r' ) as fid :
200
+ output_tfrecords = tf .python_io .TFRecordWriter (output_path )
201
+ groundtruth_data = json .load (fid )
202
+ images = groundtruth_data ['images' ]
203
+ category_index = label_map_util .create_category_index (
204
+ groundtruth_data ['categories' ])
205
+
206
+ annotations_index = {}
207
+ if 'annotations' in groundtruth_data :
208
+ tf .logging .info (
209
+ 'Found groundtruth annotations. Building annotations index.' )
210
+ for annotation in groundtruth_data ['annotations' ]:
211
+ image_id = annotation ['image_id' ]
212
+ if image_id not in annotations_index :
213
+ annotations_index [image_id ] = []
214
+ annotations_index [image_id ].append (annotation )
215
+ missing_annotation_count = 0
216
+ for image in images :
217
+ image_id = image ['id' ]
218
+ if image_id not in annotations_index :
219
+ missing_annotation_count += 1
220
+ annotations_index [image_id ] = []
221
+ tf .logging .info ('%d images are missing annotations.' ,
222
+ missing_annotation_count )
223
+
224
+ total_num_annotations_skipped = 0
225
+ for idx , image in enumerate (images ):
226
+ if idx % 100 == 0 :
227
+ tf .logging .info ('On image %d of %d' , idx , len (images ))
228
+ annotations_list = annotations_index [image ['id' ]]
229
+ _ , tf_example , num_annotations_skipped = create_tf_example (
230
+ image , annotations_list , image_dir , category_index , include_masks )
231
+ total_num_annotations_skipped += num_annotations_skipped
232
+ output_tfrecords .write (tf_example .SerializeToString ())
233
+ tf .logging .info ('Finished writing, skipped %d annotations.' ,
234
+ total_num_annotations_skipped )
235
+
236
+
237
+ def main (_ ):
238
+ assert FLAGS .train_image_dir , '`train_image_dir` missing.'
239
+ assert FLAGS .test_image_dir , '`test_image_dir` missing.'
240
+ assert FLAGS .train_annotations_file , '`train_annotations_file` missing.'
241
+ assert FLAGS .test_annotations_file , '`test_annotations_file` missing.'
242
+
243
+ if not tf .gfile .IsDirectory (FLAGS .output_dir ):
244
+ tf .gfile .MakeDirs (FLAGS .output_dir )
245
+ train_output_path = os .path .join (FLAGS .output_dir , 'train.record' )
246
+ testdev_output_path = os .path .join (FLAGS .output_dir , 'test.record' )
247
+
248
+ _create_tf_record_from_coco_annotations (
249
+ FLAGS .train_annotations_file ,
250
+ FLAGS .train_image_dir ,
251
+ train_output_path ,
252
+ FLAGS .include_masks )
253
+ _create_tf_record_from_coco_annotations (
254
+ FLAGS .test_annotations_file ,
255
+ FLAGS .test_image_dir ,
256
+ testdev_output_path ,
257
+ FLAGS .include_masks )
258
+
259
+
260
+ if __name__ == '__main__' :
261
+ tf .app .run ()
0 commit comments