Skip to content

Commit 66c18fb

Browse files
ziyeqinghancopybara-github
authored andcommitted
Add object detection model spec in TFLite Model Maker
PiperOrigin-RevId: 351298636
1 parent 997f8ce commit 66c18fb

File tree

6 files changed

+369
-0
lines changed

6 files changed

+369
-0
lines changed
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Model specification for object detection."""
15+
16+
import collections
17+
import os
18+
import tempfile
19+
20+
from absl import logging
21+
import tensorflow as tf
22+
from tensorflow_examples.lite.model_maker.core import compat
23+
24+
from tensorflow_examples.lite.model_maker.third_party.efficientdet import coco_metric
25+
from tensorflow_examples.lite.model_maker.third_party.efficientdet import hparams_config
26+
from tensorflow_examples.lite.model_maker.third_party.efficientdet import utils
27+
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import label_util
28+
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import postprocess
29+
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
30+
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib
31+
32+
33+
def _get_ordered_label_map(label_map):
34+
"""Gets label_map as an OrderedDict instance with ids sorted."""
35+
if not label_map:
36+
return label_map
37+
ordered_label_map = collections.OrderedDict()
38+
for idx in sorted(label_map.keys()):
39+
ordered_label_map[idx] = label_map[idx]
40+
return ordered_label_map
41+
42+
43+
class EfficientDetModelSpec(object):
44+
"""A specification of the EfficientDet model."""
45+
46+
compat_tf_versions = compat.get_compat_tf_versions(2)
47+
48+
def __init__(self,
49+
model_name,
50+
uri,
51+
hparams='',
52+
model_dir=None,
53+
epochs=50,
54+
batch_size=64,
55+
steps_per_execution=1,
56+
moving_average_decay=0,
57+
var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
58+
strategy=None,
59+
tpu=None,
60+
gcp_project=None,
61+
tpu_zone=None,
62+
use_xla=False,
63+
profile=False,
64+
debug=False,
65+
tf_random_seed=111111):
66+
"""Initialze an instance with model paramaters.
67+
68+
Args:
69+
model_name: Model name.
70+
uri: TF-Hub path/url to EfficientDet module.
71+
hparams: Hyperparameters used to overwrite default configuration. Can be
72+
1) Dict, contains parameter names and values; 2) String, Comma separated
73+
k=v pairs of hyperparameters; 3) String, yaml filename which's a module
74+
containing attributes to use as hyperparameters.
75+
model_dir: The location to save the model checkpoint files.
76+
epochs: Default training epochs.
77+
batch_size: Training & Evaluation batch size.
78+
steps_per_execution: Number of steps per training execution.
79+
moving_average_decay: Float. The decay to use for maintaining moving
80+
averages of the trained parameters.
81+
var_freeze_expr: Expression to freeze variables.
82+
strategy: A string specifying which distribution strategy to use.
83+
Accepted values are 'tpu', 'gpus', None. tpu' means to use TPUStrategy.
84+
'gpus' mean to use MirroredStrategy for multi-gpus. If None, use TF
85+
default with OneDeviceStrategy.
86+
tpu: The Cloud TPU to use for training. This should be either the name
87+
used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470
88+
url.
89+
gcp_project: Project name for the Cloud TPU-enabled project. If not
90+
specified, we will attempt to automatically detect the GCE project from
91+
metadata.
92+
tpu_zone: GCE zone where the Cloud TPU is located in. If not specified, we
93+
will attempt to automatically detect the GCE project from metadata.
94+
use_xla: Use XLA even if strategy is not tpu. If strategy is tpu, always
95+
use XLA, and this flag has no effect.
96+
profile: Enable profile mode.
97+
debug: Enable debug mode.
98+
tf_random_seed: Fixed random seed for deterministic execution across runs
99+
for debugging.
100+
"""
101+
self.model_name = model_name
102+
self.uri = uri
103+
self.batch_size = batch_size
104+
config = hparams_config.get_efficientdet_config(model_name)
105+
config.override(hparams)
106+
config.image_size = utils.parse_image_size(config.image_size)
107+
config.var_freeze_expr = var_freeze_expr
108+
config.moving_average_decay = moving_average_decay
109+
if epochs:
110+
config.num_epochs = epochs
111+
112+
if use_xla and strategy != 'tpu':
113+
tf.config.optimizer.set_jit(True)
114+
for gpu in tf.config.list_physical_devices('GPU'):
115+
tf.config.experimental.set_memory_growth(gpu, True)
116+
117+
if debug:
118+
tf.config.experimental_run_functions_eagerly(True)
119+
tf.debugging.set_log_device_placement(True)
120+
os.environ['TF_DETERMINISTIC_OPS'] = '1'
121+
tf.random.set_seed(tf_random_seed)
122+
logging.set_verbosity(logging.DEBUG)
123+
124+
if strategy == 'tpu':
125+
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
126+
tpu, zone=tpu_zone, project=gcp_project)
127+
tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
128+
tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
129+
ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
130+
logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
131+
tf.config.set_soft_device_placement(True)
132+
elif strategy == 'gpus':
133+
ds_strategy = tf.distribute.MirroredStrategy()
134+
logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
135+
else:
136+
if tf.config.list_physical_devices('GPU'):
137+
ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
138+
else:
139+
ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0')
140+
141+
self.ds_strategy = ds_strategy
142+
143+
if model_dir is None:
144+
model_dir = tempfile.mkdtemp()
145+
params = dict(
146+
profile=profile,
147+
model_name=model_name,
148+
steps_per_execution=steps_per_execution,
149+
model_dir=model_dir,
150+
strategy=strategy,
151+
batch_size=batch_size,
152+
tf_random_seed=tf_random_seed,
153+
debug=debug)
154+
config.override(params, True)
155+
self.config = config
156+
157+
# set mixed precision policy by keras api.
158+
precision = utils.get_precision(config.strategy, config.mixed_precision)
159+
policy = tf.keras.mixed_precision.experimental.Policy(precision)
160+
tf.keras.mixed_precision.experimental.set_policy(policy)
161+
162+
def create_model(self):
163+
"""Creates the EfficientDet model."""
164+
return train_lib.EfficientDetNetTrainHub(
165+
config=self.config, hub_module_url=self.uri)
166+
167+
def train(self,
168+
model,
169+
train_dataset,
170+
steps_per_epoch,
171+
val_dataset,
172+
validation_steps,
173+
epochs=None,
174+
batch_size=None,
175+
val_json_file=None):
176+
"""Run EfficientDet training."""
177+
config = self.config
178+
if not epochs:
179+
epochs = config.num_epochs
180+
181+
if not batch_size:
182+
batch_size = config.batch_size
183+
184+
config.update(
185+
dict(
186+
steps_per_epoch=steps_per_epoch,
187+
eval_samples=batch_size * validation_steps,
188+
val_json_file=val_json_file,
189+
batch_size=batch_size))
190+
train.setup_model(model, config)
191+
train.init_experimental(config)
192+
model.fit(
193+
train_dataset,
194+
epochs=epochs,
195+
steps_per_epoch=steps_per_epoch,
196+
callbacks=train_lib.get_callbacks(config.as_dict(), val_dataset),
197+
validation_data=val_dataset,
198+
validation_steps=validation_steps)
199+
return model
200+
201+
def evaluate(self, model, dataset, steps, json_file=None):
202+
"""Evaluate the EfficientDet keras model."""
203+
label_map = label_util.get_label_map(self.config.label_map)
204+
# Sorts label_map.keys since pycocotools.cocoeval uses sorted catIds
205+
# (category ids) in COCOeval class.
206+
label_map = _get_ordered_label_map(label_map)
207+
208+
evaluator = coco_metric.EvaluationMetric(
209+
filename=json_file, label_map=label_map)
210+
211+
evaluator.reset_states()
212+
dataset = dataset.take(steps)
213+
214+
@tf.function
215+
def _get_detections(images, labels):
216+
cls_outputs, box_outputs = model(images, training=False)
217+
detections = postprocess.generate_detections(self.config, cls_outputs,
218+
box_outputs,
219+
labels['image_scales'],
220+
labels['source_ids'])
221+
tf.numpy_function(evaluator.update_state, [
222+
labels['groundtruth_data'],
223+
postprocess.transform_detections(detections)
224+
], [])
225+
226+
dataset = self.ds_strategy.experimental_distribute_dataset(dataset)
227+
for (images, labels) in dataset:
228+
self.ds_strategy.run(_get_detections, (images, labels))
229+
230+
metrics = evaluator.result()
231+
metric_dict = {}
232+
for i, name in enumerate(evaluator.metric_names):
233+
metric_dict[name] = metrics[i]
234+
235+
if label_map:
236+
for i, cid in enumerate(label_map.keys()):
237+
name = 'AP_/%s' % label_map[cid]
238+
metric_dict[name] = metrics[i + len(evaluator.metric_names)]
239+
return metric_dict
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License');
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an 'AS IS' BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for object detector specs."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import math
21+
import os
22+
23+
import tensorflow.compat.v2 as tf
24+
25+
from tensorflow_examples.lite.model_maker.core import test_util
26+
from tensorflow_examples.lite.model_maker.core.task.model_spec import object_detector_spec
27+
28+
29+
class EfficientDetModelSpecTest(tf.test.TestCase):
30+
31+
@classmethod
32+
def setUpClass(cls):
33+
super(EfficientDetModelSpecTest, cls).setUpClass()
34+
hub_path = test_util.get_test_data_path('fake_effdet_lite0_hub')
35+
cls._spec = object_detector_spec.EfficientDetModelSpec(
36+
model_name='efficientdet-lite0', uri=hub_path, hparams=dict(map_freq=1))
37+
cls.model = cls._spec.create_model()
38+
39+
def test_create_model(self):
40+
self.assertIsInstance(self.model, tf.keras.Model)
41+
x = tf.ones((1, *self._spec.config.image_size, 3))
42+
cls_outputs, box_outputs = self.model(x)
43+
self.assertLen(cls_outputs, 5)
44+
self.assertLen(box_outputs, 5)
45+
46+
def test_train(self):
47+
model = self._spec.train(
48+
self.model,
49+
train_dataset=self._gen_input(),
50+
steps_per_epoch=1,
51+
val_dataset=self._gen_input(),
52+
validation_steps=1,
53+
epochs=1,
54+
batch_size=1)
55+
self.assertIsInstance(model, tf.keras.Model)
56+
57+
def test_evaluate(self):
58+
metrics = self._spec.evaluate(
59+
self.model, dataset=self._gen_input(), steps=1)
60+
self.assertIsInstance(metrics, dict)
61+
self.assertGreaterEqual(metrics['AP'], 0)
62+
63+
def _gen_input(self):
64+
# Image tensors that are preprocessed to have normalized value and fixed
65+
# dimension [1, image_height, image_width, 3]
66+
images = tf.random.uniform((1, 320, 320, 3), maxval=256)
67+
68+
# labels contains:
69+
# box_targets_dict: ordered dictionary with keys
70+
# [min_level, min_level+1, ..., max_level]. The values are tensor with
71+
# shape [height_l, width_l, num_anchors * 4]. The height_l and
72+
# width_l represent the dimension of bounding box regression output at
73+
# l-th level.
74+
# cls_targets_dict: ordered dictionary with keys
75+
# [min_level, min_level+1, ..., max_level]. The values are tensor with
76+
# shape [height_l, width_l, num_anchors]. The height_l and width_l
77+
# represent the dimension of class logits at l-th level.
78+
# groundtruth_data: Groundtruth Annotations data.
79+
# image_scale: Scale of the processed image to the original image.
80+
# source_id: Source image id. Default value -1 if the source id is empty
81+
# in the groundtruth annotation.
82+
# mean_num_positives: Mean number of positive anchors in the batch images.
83+
sizes = [(level, math.ceil(320 / 2**level)) for level in range(3, 8)]
84+
85+
labels = {
86+
'box_targets_%d' % level: tf.ones((1, size, size, 36))
87+
for level, size in sizes
88+
}
89+
labels.update({
90+
'cls_targets_%d' % level: tf.ones((1, size, size, 9), dtype=tf.int32)
91+
for level, size in sizes
92+
})
93+
labels.update({'groundtruth_data': tf.zeros([1, 100, 7])})
94+
labels.update({'image_scales': tf.constant([0.8])})
95+
labels.update({'source_ids': tf.constant([1.0])})
96+
labels.update({'mean_num_positives': tf.constant([10.0])})
97+
ds = tf.data.Dataset.from_tensors((images, labels))
98+
return ds
99+
100+
101+
if __name__ == '__main__':
102+
# Load compressed models from tensorflow_hub
103+
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
104+
tf.test.main()

0 commit comments

Comments
 (0)