Skip to content

Commit a5eb9a9

Browse files
ziyeqinghanMarkDaoust
authored andcommitted
Add export_saved_model and export_tflite for object detection in Model Maker
PiperOrigin-RevId: 352517354
1 parent 5e90c6c commit a5eb9a9

File tree

8 files changed

+159
-8
lines changed

8 files changed

+159
-8
lines changed

tensorflow_examples/lite/model_maker/core/task/configs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,13 @@ def create_float16_quantization(cls, optimizations=tf.lite.Optimize.DEFAULT):
179179
"""Creates configuration for float16 quantization."""
180180
return QuantizationConfig(optimizations, supported_types=[tf.float16])
181181

182-
def get_converter_with_quantization(self, converter, preprocess=None):
182+
def get_converter_with_quantization(self, converter, **kwargs):
183183
"""Gets TFLite converter with settings for quantization."""
184184
converter.optimizations = self.optimizations
185185

186186
if self.representative_data is not None:
187187
ds = self.representative_data.gen_dataset(
188-
batch_size=1, is_training=False, preprocess=preprocess)
188+
batch_size=1, is_training=False, **kwargs)
189189
converter.representative_dataset = tf.lite.RepresentativeDataset(
190190
_get_representative_dataset_gen(ds, self.quantization_steps))
191191

tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py

+92
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
from tensorflow_examples.lite.model_maker.third_party.efficientdet import coco_metric
2525
from tensorflow_examples.lite.model_maker.third_party.efficientdet import hparams_config
2626
from tensorflow_examples.lite.model_maker.third_party.efficientdet import utils
27+
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import efficientdet_keras
28+
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import inference
2729
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import label_util
2830
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import postprocess
2931
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
3032
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib
33+
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import util_keras
3134

3235

3336
def _get_ordered_label_map(label_map):
@@ -237,3 +240,92 @@ def _get_detections(images, labels):
237240
name = 'AP_/%s' % label_map[cid]
238241
metric_dict[name] = metrics[i + len(evaluator.metric_names)]
239242
return metric_dict
243+
244+
def export_saved_model(self,
245+
saved_model_dir,
246+
batch_size=None,
247+
pre_mode='infer',
248+
post_mode='global'):
249+
"""Saves the model to Tensorflow SavedModel.
250+
251+
Args:
252+
saved_model_dir: Folder path for saved model.
253+
batch_size: Batch size to be saved in saved_model.
254+
pre_mode: Pre-processing Mode in ExportModel, must be {None, 'infer'}.
255+
post_mode: Post-processing Mode in ExportModel, must be {None, 'global',
256+
'per_class'}.
257+
"""
258+
# Create EfficientDetModel with latest checkpoint.
259+
config = self.config
260+
model = efficientdet_keras.EfficientDetModel(config=config)
261+
model.build((batch_size, *config.image_size, 3))
262+
if config.model_dir:
263+
util_keras.restore_ckpt(
264+
model,
265+
config.model_dir,
266+
config['moving_average_decay'],
267+
skip_mismatch=False)
268+
else:
269+
# EfficientDetModel is random initialized without restoring the
270+
# checkpoint. This is mainly used in object_detector_test and shouldn't be
271+
# used if we want to export trained model.
272+
tf.compat.v1.logging.warn('Need to restore the checkpoint for '
273+
'EfficientDet.')
274+
# Gets tf.TensorSpec.
275+
if pre_mode is None:
276+
# Input is the preprocessed image that's already resized to a certain
277+
# input shape.
278+
input_spec = tf.TensorSpec(
279+
shape=[batch_size, *config.image_size, 3],
280+
dtype=tf.float32,
281+
name='images')
282+
else:
283+
# Input is that raw image that can be in any input shape,
284+
input_spec = tf.TensorSpec(
285+
shape=[batch_size, None, None, 3], dtype=tf.uint8, name='images')
286+
287+
export_model = inference.ExportModel(
288+
model, pre_mode=pre_mode, post_mode=post_mode)
289+
tf.saved_model.save(
290+
export_model,
291+
saved_model_dir,
292+
signatures=export_model.__call__.get_concrete_function(input_spec))
293+
294+
def export_tflite(self, tflite_filepath, quantization_config=None):
295+
"""Converts the retrained model to tflite format and saves it.
296+
297+
The exported TFLite model has the following inputs & outputs:
298+
One input:
299+
image: a float32 tensor of shape[1, height, width, 3] containing the
300+
normalized input image. `self.config.image_size` is [height, width].
301+
302+
Four Outputs:
303+
detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
304+
locations.
305+
detection_classes: a float32 tensor of shape [1, num_boxes] with class
306+
indices.
307+
detection_scores: a float32 tensor of shape [1, num_boxes] with class
308+
scores.
309+
num_boxes: a float32 tensor of size 1 containing the number of detected
310+
boxes.
311+
312+
Args:
313+
tflite_filepath: File path to save tflite model.
314+
quantization_config: Configuration for post-training quantization.
315+
"""
316+
with tempfile.TemporaryDirectory() as temp_dir:
317+
self.export_saved_model(
318+
temp_dir, batch_size=1, pre_mode=None, post_mode='tflite')
319+
converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir)
320+
if quantization_config:
321+
converter = quantization_config.get_converter_with_quantization(
322+
converter, model_spec=self)
323+
324+
# TFLITE_BUILTINS is needed for TFLite's custom NMS op for integer only
325+
# quantization.
326+
if tf.lite.OpsSet.TFLITE_BUILTINS not in converter.target_spec.supported_ops:
327+
converter.target_spec.supported_ops += [tf.lite.OpsSet.TFLITE_BUILTINS]
328+
tflite_model = converter.convert()
329+
330+
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
331+
f.write(tflite_model)

tensorflow_examples/lite/model_maker/core/task/model_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def export_tflite(model,
9191

9292
if quantization_config:
9393
converter = quantization_config.get_converter_with_quantization(
94-
converter, preprocess)
94+
converter, preprocess=preprocess)
9595

9696
converter.target_spec.supported_ops = supported_ops
9797
tflite_model = converter.convert()

tensorflow_examples/lite/model_maker/core/task/object_detector.py

+13
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,16 @@ def evaluate(self, data, batch_size=None):
117117

118118
return self.model_spec.evaluate(self.model, ds, steps,
119119
data.annotations_json_file)
120+
121+
def _export_saved_model(self, saved_model_dir):
122+
"""Saves the model to Tensorflow SavedModel."""
123+
self.model_spec.export_saved_model(saved_model_dir)
124+
125+
def _export_tflite(self, tflite_filepath, quantization_config=None):
126+
"""Converts the retrained model to tflite format and saves it.
127+
128+
Args:
129+
tflite_filepath: File path to save tflite model.
130+
quantization_config: Configuration for post-training quantization.
131+
"""
132+
self.model_spec.export_tflite(tflite_filepath, quantization_config)

tensorflow_examples/lite/model_maker/core/task/object_detector_test.py

+34
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from tensorflow_examples.lite.model_maker.core import compat
2323
from tensorflow_examples.lite.model_maker.core import test_util
2424
from tensorflow_examples.lite.model_maker.core.data_util import object_detector_dataloader
25+
from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
26+
from tensorflow_examples.lite.model_maker.core.task import configs
2527
from tensorflow_examples.lite.model_maker.core.task import object_detector
2628
from tensorflow_examples.lite.model_maker.core.task.model_spec import object_detector_spec
2729

@@ -48,6 +50,38 @@ def testEfficientDetLite0(self):
4850
self.assertIsInstance(metrics, dict)
4951
self.assertGreaterEqual(metrics['AP'], 0)
5052

53+
# Export the model to saved model.
54+
spec.config.model_dir = None # Don't restore checkpoint.
55+
output_path = os.path.join(self.get_temp_dir(), 'saved_model')
56+
task.export(self.get_temp_dir(), export_format=ExportFormat.SAVED_MODEL)
57+
self.assertTrue(os.path.isdir(output_path))
58+
self.assertNotEqual(len(os.listdir(output_path)), 0)
59+
60+
# Export the model to TFLite model.
61+
output_path = os.path.join(self.get_temp_dir(), 'float.tflite')
62+
task.export(
63+
self.get_temp_dir(),
64+
tflite_filename='float.tflite',
65+
export_format=ExportFormat.TFLITE)
66+
self.assertTrue(tf.io.gfile.exists(output_path))
67+
self.assertGreater(os.path.getsize(output_path), 0)
68+
69+
# Export the model to quantized TFLite model.
70+
# TODO(b/175173304): Skips the test for stable tensorflow 2.4 for now since
71+
# it fails. Will revert this change after TF upgrade.
72+
if tf.__version__.startswith('2.4'):
73+
return
74+
output_path = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
75+
config = configs.QuantizationConfig.create_full_integer_quantization(
76+
data, is_integer_only=True)
77+
task.export(
78+
self.get_temp_dir(),
79+
tflite_filename='model_quantized.tflite',
80+
quantization_config=config,
81+
export_format=ExportFormat.TFLITE)
82+
self.assertTrue(os.path.isfile(output_path))
83+
self.assertGreater(os.path.getsize(output_path), 0)
84+
5185

5286
if __name__ == '__main__':
5387
# Load compressed models from tensorflow_hub

tensorflow_examples/lite/model_maker/pip_package/setup.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,10 @@ def _read_required_packages(fpath):
9191

9292
def get_required_packages():
9393
"""Gets packages inside requirements.txt."""
94-
# Gets third party's required packages.
95-
fpath = BASE_DIR.joinpath('third_party', 'efficientdet', 'requirements.txt')
96-
required_pkgs = _read_required_packages(fpath)
97-
9894
# Gets model maker's required packages
9995
filename = 'requirements_nightly.txt' if nightly else 'requirements.txt'
10096
fpath = BASE_DIR.joinpath(filename)
101-
required_pkgs += _read_required_packages(fpath)
97+
required_pkgs = _read_required_packages(fpath)
10298

10399
return required_pkgs
104100

tensorflow_examples/lite/model_maker/requirements.txt

+8
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@ tensorflow>=2.4.0
1414
librosa>=0.5
1515
lxml>=4.6.1
1616
PyYAML>=5.1
17+
# The following are the requirements of efficientdet.
18+
matplotlib>=3.0.3
19+
six>=1.12.0
20+
tensorflow-addons>=0.11.2
21+
neural-structured-learning>=1.3.1
22+
tensorflow-model-optimization>=0.5
23+
Cython>=0.29.13
24+
pycocotools>=2.0.2

tensorflow_examples/lite/model_maker/requirements_nightly.txt

+8
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,11 @@ tensorflow>=2.4.0
1515
librosa>=0.5
1616
lxml>=4.6.1
1717
PyYAML>=5.1
18+
# The following are the requirements of efficientdet.
19+
matplotlib>=3.0.3
20+
six>=1.12.0
21+
tfa-nightly
22+
neural-structured-learning>=1.3.1
23+
tensorflow-model-optimization>=0.5
24+
Cython>=0.29.13
25+
pycocotools>=2.0.2

0 commit comments

Comments
 (0)