|
24 | 24 | from tensorflow_examples.lite.model_maker.third_party.efficientdet import coco_metric
|
25 | 25 | from tensorflow_examples.lite.model_maker.third_party.efficientdet import hparams_config
|
26 | 26 | from tensorflow_examples.lite.model_maker.third_party.efficientdet import utils
|
| 27 | +from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import efficientdet_keras |
| 28 | +from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import inference |
27 | 29 | from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import label_util
|
28 | 30 | from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import postprocess
|
29 | 31 | from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
|
30 | 32 | from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib
|
| 33 | +from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import util_keras |
31 | 34 |
|
32 | 35 |
|
33 | 36 | def _get_ordered_label_map(label_map):
|
@@ -237,3 +240,92 @@ def _get_detections(images, labels):
|
237 | 240 | name = 'AP_/%s' % label_map[cid]
|
238 | 241 | metric_dict[name] = metrics[i + len(evaluator.metric_names)]
|
239 | 242 | return metric_dict
|
| 243 | + |
  def export_saved_model(self,
                         saved_model_dir,
                         batch_size=None,
                         pre_mode='infer',
                         post_mode='global'):
    """Saves the model to Tensorflow SavedModel.

    Builds an `EfficientDetModel` from `self.config`, restores the latest
    checkpoint from `config.model_dir` when one is set, wraps the model in
    `inference.ExportModel`, and writes it out with `tf.saved_model.save`
    using a single concrete serving signature.

    Args:
      saved_model_dir: Folder path for saved model.
      batch_size: Batch size to be saved in saved_model. `None` keeps the
        batch dimension dynamic in the exported signature.
      pre_mode: Pre-processing Mode in ExportModel, must be {None, 'infer'}.
        With `None`, the signature input is a preprocessed float32 batch
        already resized to `config.image_size`; otherwise it accepts a raw
        uint8 image of any spatial size.
      post_mode: Post-processing Mode in ExportModel, must be {None, 'global',
        'per_class', 'tflite'} ('tflite' is the mode `export_tflite` uses
        when exporting for TFLite conversion).
    """
    # Create EfficientDetModel with latest checkpoint.
    config = self.config
    model = efficientdet_keras.EfficientDetModel(config=config)
    # Build variables for a [batch, height, width, channels] input so the
    # checkpoint below has matching variables to restore into.
    model.build((batch_size, *config.image_size, 3))
    if config.model_dir:
      util_keras.restore_ckpt(
          model,
          config.model_dir,
          config['moving_average_decay'],
          skip_mismatch=False)
    else:
      # EfficientDetModel is random initialized without restoring the
      # checkpoint. This is mainly used in object_detector_test and shouldn't be
      # used if we want to export trained model.
      tf.compat.v1.logging.warn('Need to restore the checkpoint for '
                                'EfficientDet.')
    # Gets tf.TensorSpec describing the exported signature's input.
    if pre_mode is None:
      # Input is the preprocessed image that's already resized to a certain
      # input shape.
      input_spec = tf.TensorSpec(
          shape=[batch_size, *config.image_size, 3],
          dtype=tf.float32,
          name='images')
    else:
      # Input is the raw image, which can be of any spatial shape.
      input_spec = tf.TensorSpec(
          shape=[batch_size, None, None, 3], dtype=tf.uint8, name='images')

    export_model = inference.ExportModel(
        model, pre_mode=pre_mode, post_mode=post_mode)
    tf.saved_model.save(
        export_model,
        saved_model_dir,
        signatures=export_model.__call__.get_concrete_function(input_spec))
| 293 | + |
| 294 | + def export_tflite(self, tflite_filepath, quantization_config=None): |
| 295 | + """Converts the retrained model to tflite format and saves it. |
| 296 | +
|
| 297 | + The exported TFLite model has the following inputs & outputs: |
| 298 | + One input: |
| 299 | + image: a float32 tensor of shape[1, height, width, 3] containing the |
| 300 | + normalized input image. `self.config.image_size` is [height, width]. |
| 301 | +
|
| 302 | + Four Outputs: |
| 303 | + detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box |
| 304 | + locations. |
| 305 | + detection_classes: a float32 tensor of shape [1, num_boxes] with class |
| 306 | + indices. |
| 307 | + detection_scores: a float32 tensor of shape [1, num_boxes] with class |
| 308 | + scores. |
| 309 | + num_boxes: a float32 tensor of size 1 containing the number of detected |
| 310 | + boxes. |
| 311 | +
|
| 312 | + Args: |
| 313 | + tflite_filepath: File path to save tflite model. |
| 314 | + quantization_config: Configuration for post-training quantization. |
| 315 | + """ |
| 316 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 317 | + self.export_saved_model( |
| 318 | + temp_dir, batch_size=1, pre_mode=None, post_mode='tflite') |
| 319 | + converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir) |
| 320 | + if quantization_config: |
| 321 | + converter = quantization_config.get_converter_with_quantization( |
| 322 | + converter, model_spec=self) |
| 323 | + |
| 324 | + # TFLITE_BUILTINS is needed for TFLite's custom NMS op for integer only |
| 325 | + # quantization. |
| 326 | + if tf.lite.OpsSet.TFLITE_BUILTINS not in converter.target_spec.supported_ops: |
| 327 | + converter.target_spec.supported_ops += [tf.lite.OpsSet.TFLITE_BUILTINS] |
| 328 | + tflite_model = converter.convert() |
| 329 | + |
| 330 | + with tf.io.gfile.GFile(tflite_filepath, 'wb') as f: |
| 331 | + f.write(tflite_model) |
0 commit comments