Skip to content

Commit 4915c84

Browse files
committed
fixed broken url in notebook, added trained model and file to convert from coco to tfrecord
1 parent d9c1e66 commit 4915c84

File tree

4 files changed

+264
-2
lines changed

4 files changed

+264
-2
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.idea

Diff for: Tensorflow_2_Object_Detection_Train_model.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434
"metadata": {},
3535
"source": [
3636
"<table align=\"left\"><td>\n",
37-
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/TannerGilbert/Tutorials/blob/master/Tensorflow-Object-Detection-API-Train-Model/Tensorflow_2_Object_Detection_Train_model\">\n",
37+
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/TannerGilbert/Tensorflow-Object-Detection-API-Train-Model/blob/master/Tensorflow_2_Object_Detection_Train_model.ipynb\">\n",
3838
" <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab\n",
3939
" </a>\n",
4040
"</td><td>\n",
41-
" <a target=\"_blank\" href=\"https://github.com/TannerGilbert/Tutorials/blob/master/Tensorflow-Object-Detection-API-Train-Model/Tensorflow_2_Object_Detection_Train_model\">\n",
41+
" <a target=\"_blank\" href=\"https://github.com/TannerGilbert/Tensorflow-Object-Detection-API-Train-Model/blob/master/Tensorflow_2_Object_Detection_Train_model.ipynb\">\n",
4242
" <img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
4343
"</td></table>"
4444
]

Diff for: create_coco_tf_record.py

+261
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
r"""Convert raw COCO dataset to TFRecord for object_detection.
17+
18+
Please note that this tool writes one unsharded output file per split (train.record and test.record).
19+
20+
Example usage:
21+
python create_coco_tf_record.py --logtostderr \
22+
--train_image_dir="${TRAIN_IMAGE_DIR}" \
23+
--test_image_dir="${TEST_IMAGE_DIR}" \
24+
--train_annotations_file="${TRAIN_ANNOTATIONS_FILE}" \
25+
--test_annotations_file="${TEST_ANNOTATIONS_FILE}" \
26+
--output_dir="${OUTPUT_DIR}"
27+
"""
28+
from __future__ import absolute_import
29+
from __future__ import division
30+
from __future__ import print_function
31+
32+
import hashlib
33+
import io
34+
import json
35+
import os
36+
import contextlib2
37+
import numpy as np
38+
import PIL.Image
39+
40+
from pycocotools import mask
41+
import tensorflow as tf
42+
43+
from object_detection.dataset_tools import tf_record_creation_util
44+
from object_detection.utils import dataset_util
45+
from object_detection.utils import label_map_util
46+
47+
48+
# Command-line flags. The image-directory and annotation-file flags default to
# the empty string; main() asserts that each of them has been set.
flags = tf.app.flags
tf.flags.DEFINE_boolean(
    'include_masks', False,
    'Whether to include instance segmentations masks '
    '(PNG encoded) in the result. default: False.')
tf.flags.DEFINE_string(
    'train_image_dir', '',
    'Training image directory.')
tf.flags.DEFINE_string(
    'test_image_dir', '',
    'Test image directory.')
tf.flags.DEFINE_string(
    'train_annotations_file', '',
    'Training annotations JSON file.')
tf.flags.DEFINE_string(
    'test_annotations_file', '',
    'Test-dev annotations JSON file.')
tf.flags.DEFINE_string(
    'output_dir', '/tmp/',
    'Output data directory.')

# Parsed flag values, read by main().
FLAGS = flags.FLAGS

# Emit the INFO-level progress messages logged during conversion.
tf.logging.set_verbosity(tf.logging.INFO)
65+
66+
67+
def create_tf_example(image,
                      annotations_list,
                      image_dir,
                      category_index,
                      include_masks=False):
    """Converts one COCO image and its annotations to a tf.Example proto.

    Args:
      image: dict with keys:
        [u'license', u'file_name', u'coco_url', u'height', u'width',
         u'date_captured', u'flickr_url', u'id']
        Only 'height', 'width', 'file_name' and 'id' are read here.
      annotations_list: list of dicts with keys:
        [u'segmentation', u'area', u'iscrowd', u'image_id',
         u'bbox', u'category_id', u'id']
        Notice that bounding box coordinates in the official COCO dataset are
        given as [x, y, width, height] tuples using absolute coordinates where
        x, y represent the top-left (0-indexed) corner. This function converts
        them to the format expected by the Tensorflow Object Detection API
        (which is [ymin, xmin, ymax, xmax] with coordinates normalized relative
        to image size).
      image_dir: directory containing the image files.
      category_index: a dict containing COCO category information keyed
        by the 'id' field of each category. See the
        label_map_util.create_category_index function.
      include_masks: Whether to include instance segmentations masks
        (PNG encoded) in the result. default: False.

    Returns:
      key: hex-encoded SHA-256 digest of the encoded image bytes.
      example: The converted tf.Example.
      num_annotations_skipped: Number of (invalid) annotations that were
        ignored.
    """
    image_height = image['height']
    image_width = image['width']
    filename = image['file_name']
    image_id = image['id']

    full_path = os.path.join(image_dir, filename)
    with tf.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()
    # The decoded image is never used; the open call is kept so that any error
    # PIL raises for unreadable image bytes still propagates to the caller.
    # The result is deliberately not bound to `image`, which would shadow the
    # metadata dict passed in.
    PIL.Image.open(io.BytesIO(encoded_jpg))
    key = hashlib.sha256(encoded_jpg).hexdigest()

    xmin = []
    xmax = []
    ymin = []
    ymax = []
    is_crowd = []
    category_names = []
    area = []
    encoded_mask_png = []
    num_annotations_skipped = 0
    for object_annotations in annotations_list:
        (x, y, width, height) = tuple(object_annotations['bbox'])
        # Skip degenerate boxes.
        if width <= 0 or height <= 0:
            num_annotations_skipped += 1
            continue
        # Skip boxes extending past the right or bottom image edge.
        # NOTE(review): a negative x or y is not rejected here.
        if x + width > image_width or y + height > image_height:
            num_annotations_skipped += 1
            continue
        xmin.append(float(x) / image_width)
        xmax.append(float(x + width) / image_width)
        ymin.append(float(y) / image_height)
        ymax.append(float(y + height) / image_height)
        is_crowd.append(object_annotations['iscrowd'])
        category_id = int(object_annotations['category_id'])
        category_names.append(
            category_index[category_id]['name'].encode('utf8'))
        area.append(object_annotations['area'])

        if include_masks:
            run_len_encoding = mask.frPyObjects(
                object_annotations['segmentation'],
                image_height, image_width)
            binary_mask = mask.decode(run_len_encoding)
            if not object_annotations['iscrowd']:
                # Collapse the decoded masks along the last axis into a
                # single 2-D mask for this object.
                binary_mask = np.amax(binary_mask, axis=2)
            pil_image = PIL.Image.fromarray(binary_mask)
            output_io = io.BytesIO()
            pil_image.save(output_io, format='PNG')
            encoded_mask_png.append(output_io.getvalue())

    feature_dict = {
        'image/height':
            dataset_util.int64_feature(image_height),
        'image/width':
            dataset_util.int64_feature(image_width),
        'image/filename':
            dataset_util.bytes_feature(filename.encode('utf8')),
        'image/source_id':
            dataset_util.bytes_feature(str(image_id).encode('utf8')),
        'image/key/sha256':
            dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded':
            dataset_util.bytes_feature(encoded_jpg),
        'image/format':
            dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin':
            dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax':
            dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin':
            dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax':
            dataset_util.float_list_feature(ymax),
        'image/object/class/text':
            dataset_util.bytes_list_feature(category_names),
        'image/object/is_crowd':
            dataset_util.int64_list_feature(is_crowd),
        'image/object/area':
            dataset_util.float_list_feature(area),
    }
    if include_masks:
        feature_dict['image/object/mask'] = (
            dataset_util.bytes_list_feature(encoded_mask_png))
    example = tf.train.Example(
        features=tf.train.Features(feature=feature_dict))
    return key, example, num_annotations_skipped
186+
187+
188+
def _create_tf_record_from_coco_annotations(
        annotations_file, image_dir, output_path, include_masks):
    """Loads a COCO annotation json file and converts it to tf.Record format.

    Every image listed in the annotations file is written as one tf.Example
    to a single TFRecord file at `output_path`. Images without any annotation
    are still written, with empty annotation lists.

    Args:
      annotations_file: JSON file containing bounding box annotations.
      image_dir: Directory containing the image files.
      output_path: Path to output tf.Record file.
      include_masks: Whether to include instance segmentations masks
        (PNG encoded) in the result.
    """
    # Parse the annotations before touching the output, so that a bad
    # annotations file does not leave a truncated record file behind.
    with tf.gfile.GFile(annotations_file, 'r') as fid:
        groundtruth_data = json.load(fid)

    images = groundtruth_data['images']
    category_index = label_map_util.create_category_index(
        groundtruth_data['categories'])

    # Group annotations by the id of the image they belong to.
    annotations_index = {}
    if 'annotations' in groundtruth_data:
        tf.logging.info(
            'Found groundtruth annotations. Building annotations index.')
        for annotation in groundtruth_data['annotations']:
            annotations_index.setdefault(
                annotation['image_id'], []).append(annotation)

    # Give every un-annotated image an empty list so the lookup below
    # cannot fail, and report how many there were.
    missing_annotation_count = 0
    for image in images:
        image_id = image['id']
        if image_id not in annotations_index:
            missing_annotation_count += 1
            annotations_index[image_id] = []
    tf.logging.info('%d images are missing annotations.',
                    missing_annotation_count)

    total_num_annotations_skipped = 0
    # The context manager closes the writer even if a conversion fails;
    # previously the writer was never closed.
    with tf.python_io.TFRecordWriter(output_path) as output_tfrecords:
        for idx, image in enumerate(images):
            if idx % 100 == 0:
                tf.logging.info('On image %d of %d', idx, len(images))
            annotations_list = annotations_index[image['id']]
            _, tf_example, num_annotations_skipped = create_tf_example(
                image, annotations_list, image_dir, category_index,
                include_masks)
            total_num_annotations_skipped += num_annotations_skipped
            output_tfrecords.write(tf_example.SerializeToString())
    tf.logging.info('Finished writing, skipped %d annotations.',
                    total_num_annotations_skipped)
235+
236+
237+
def main(_):
    """Converts the train and test COCO splits named by the flags.

    Checks that the four input flags are set, creates the output directory
    if it does not exist, then writes `train.record` followed by
    `test.record` into it.
    """
    assert FLAGS.train_image_dir, '`train_image_dir` missing.'
    assert FLAGS.test_image_dir, '`test_image_dir` missing.'
    assert FLAGS.train_annotations_file, '`train_annotations_file` missing.'
    assert FLAGS.test_annotations_file, '`test_annotations_file` missing.'

    output_dir = FLAGS.output_dir
    if not tf.gfile.IsDirectory(output_dir):
        tf.gfile.MakeDirs(output_dir)

    # (annotations file, image directory, output record path), in the order
    # the splits are converted.
    splits = (
        (FLAGS.train_annotations_file,
         FLAGS.train_image_dir,
         os.path.join(output_dir, 'train.record')),
        (FLAGS.test_annotations_file,
         FLAGS.test_image_dir,
         os.path.join(output_dir, 'test.record')),
    )
    for annotations_file, image_dir, record_path in splits:
        _create_tf_record_from_coco_annotations(
            annotations_file,
            image_dir,
            record_path,
            FLAGS.include_masks)
258+
259+
260+
# Script entry point: hand control to tf.app.run() when executed directly.
if __name__ == "__main__":
    tf.app.run()

Diff for: training/saved_model.pb

21.1 MB
Binary file not shown.

0 commit comments

Comments
 (0)