Skip to content

Commit a40bab4

Browse files
ziyeqinghancopybara-github
authored andcommitted
Add object_dectector task in TFLite Model Maker
PiperOrigin-RevId: 351309014
1 parent 66c18fb commit a40bab4

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""ObjectDetector class."""
15+
16+
import tensorflow as tf
17+
from tensorflow_examples.lite.model_maker.core import compat
18+
from tensorflow_examples.lite.model_maker.core.task import custom_model
19+
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
20+
21+
22+
def create(train_data,
23+
model_spec,
24+
validation_data=None,
25+
epochs=None,
26+
batch_size=None,
27+
do_train=True):
28+
"""Loads data and train the model for test classification.
29+
30+
Args:
31+
train_data: Training data.
32+
model_spec: Specification for the model.
33+
validation_data: Validation data. If None, skips validation process.
34+
epochs: Number of epochs for training.
35+
batch_size: Batch size for training.
36+
do_train: Whether to run training.
37+
38+
Returns:
39+
TextClassifier
40+
"""
41+
model_spec = ms.get(model_spec)
42+
if compat.get_tf_behavior() not in model_spec.compat_tf_versions:
43+
raise ValueError('Incompatible versions. Expect {}, but got {}.'.format(
44+
model_spec.compat_tf_versions, compat.get_tf_behavior()))
45+
46+
object_detector = ObjectDetector(model_spec, train_data.label_map)
47+
48+
if do_train:
49+
tf.compat.v1.logging.info('Retraining the models...')
50+
object_detector.train(train_data, validation_data, epochs, batch_size)
51+
else:
52+
object_detector.create_model()
53+
54+
return object_detector
55+
56+
57+
class ObjectDetector(custom_model.CustomModel):
58+
"""ObjectDetector class for inference and exporting to tflite."""
59+
60+
def __init__(self, model_spec, label_map):
61+
super().__init__(model_spec, shuffle=None)
62+
if model_spec.config.label_map and model_spec.config.label_map != label_map:
63+
tf.compat.v1.logging.warn(
64+
'Label map is not the same as the previous label_map in model_spec.')
65+
model_spec.config.label_map = label_map
66+
model_spec.config.num_classes = len(label_map)
67+
68+
def create_model(self):
69+
self.model = self.model_spec.create_model()
70+
return self.model
71+
72+
def _get_dataset_and_steps(self, data, batch_size, is_training):
73+
"""Gets dataset, steps and annotations json file."""
74+
if not data:
75+
return None, 0, None
76+
# TODO(b/171449557): Put this into DataLoader.
77+
dataset = data.gen_dataset(
78+
self.model_spec, batch_size, is_training=is_training)
79+
steps = len(data) // batch_size
80+
return dataset, steps, data.annotations_json_file
81+
82+
def train(self,
83+
train_data,
84+
validation_data=None,
85+
epochs=None,
86+
batch_size=None):
87+
"""Feeds the training data for training."""
88+
batch_size = batch_size if batch_size else self.model_spec.batch_size
89+
# TODO(b/171449557): Upstream this to the parent class.
90+
if len(train_data) < batch_size:
91+
raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
92+
'than batch_size (%d). To solve this problem, set '
93+
'the batch_size smaller or increase the size of the '
94+
'train_data.' % (len(train_data), batch_size))
95+
96+
with self.model_spec.ds_strategy.scope():
97+
self.create_model()
98+
train_ds, steps_per_epoch, _ = self._get_dataset_and_steps(
99+
train_data, batch_size, is_training=True)
100+
validation_ds, validation_steps, val_json_file = self._get_dataset_and_steps(
101+
validation_data, batch_size, is_training=False)
102+
return self.model_spec.train(self.model, train_ds, steps_per_epoch,
103+
validation_ds, validation_steps, epochs,
104+
batch_size, val_json_file)
105+
106+
def evaluate(self, data, batch_size=None):
107+
"""Evaluates the model."""
108+
batch_size = batch_size if batch_size else self.model_spec.batch_size
109+
ds = data.gen_dataset(self.model_spec, batch_size, is_training=False)
110+
steps = len(data) // batch_size
111+
# TODO(b/171449557): Upstream this to the parent class.
112+
if steps <= 0:
113+
raise ValueError('The size of the validation_data (%d) couldn\'t be '
114+
'smaller than batch_size (%d). To solve this problem, '
115+
'set the batch_size smaller or increase the size of the '
116+
'validation_data.' % (len(data), batch_size))
117+
118+
return self.model_spec.evaluate(self.model, ds, steps,
119+
data.annotations_json_file)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License');
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an 'AS IS' BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import os
20+
21+
import tensorflow.compat.v2 as tf
22+
from tensorflow_examples.lite.model_maker.core import compat
23+
from tensorflow_examples.lite.model_maker.core import test_util
24+
from tensorflow_examples.lite.model_maker.core.data_util import object_detector_dataloader
25+
from tensorflow_examples.lite.model_maker.core.task import object_detector
26+
from tensorflow_examples.lite.model_maker.core.task.model_spec import object_detector_spec
27+
28+
29+
class ObjectDetectorTest(tf.test.TestCase):
30+
31+
def testEfficientDetLite0(self):
32+
# Gets model specification.
33+
hub_path = test_util.get_test_data_path('fake_effdet_lite0_hub')
34+
spec = object_detector_spec.EfficientDetModelSpec(
35+
model_name='efficientdet-lite0', uri=hub_path)
36+
37+
# Prepare data.
38+
images_dir, annotations_dir, label_map = test_util.create_pascal_voc(
39+
self.get_temp_dir())
40+
data = object_detector_dataloader.DataLoader.from_pascal_voc(
41+
images_dir, annotations_dir, label_map)
42+
43+
# Train the model.
44+
task = object_detector.create(data, spec, batch_size=1, epochs=1)
45+
46+
# Evaluate trained model
47+
metrics = task.evaluate(data, batch_size=1)
48+
self.assertIsInstance(metrics, dict)
49+
self.assertGreaterEqual(metrics['AP'], 0)
50+
51+
52+
if __name__ == '__main__':
53+
# Load compressed models from tensorflow_hub
54+
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
55+
compat.setup_tf_behavior(tf_version=2)
56+
tf.test.main()

0 commit comments

Comments
 (0)