
Commit b66b0b0

Explicit signatures for tflite. Using ideas from #9688 (#10248)
1 parent c636ea3 commit b66b0b0

File tree

1 file changed: +22 -21 lines changed


research/audioset/yamnet/export.py

Lines changed: 22 additions & 21 deletions
@@ -44,28 +44,35 @@ def log(msg):
 
 
 class YAMNet(tf.Module):
-  "''A TF2 Module wrapper around YAMNet."""
+  """A TF2 Module wrapper around YAMNet."""
   def __init__(self, weights_path, params):
     super().__init__()
     self._yamnet = yamnet.yamnet_frames_model(params)
     self._yamnet.load_weights(weights_path)
     self._class_map_asset = tf.saved_model.Asset('yamnet_class_map.csv')
 
-  @tf.function
+  @tf.function(input_signature=[])
   def class_map_path(self):
     return self._class_map_asset.asset_path
 
-  @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),))
+  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
   def __call__(self, waveform):
-    return self._yamnet(waveform)
+    predictions, embeddings, log_mel_spectrogram = self._yamnet(waveform)
+
+    return {'predictions': predictions,
+            'embeddings': embeddings,
+            'log_mel_spectrogram': log_mel_spectrogram}
 
 
 def check_model(model_fn, class_map_path, params):
   yamnet_classes = yamnet.class_names(class_map_path)
 
   """Applies yamnet_test's sanity checks to an instance of YAMNet."""
   def clip_test(waveform, expected_class_name, top_n=10):
-    predictions, embeddings, log_mel_spectrogram = model_fn(waveform)
+    results = model_fn(waveform=waveform)
+    predictions = results['predictions']
+    embeddings = results['embeddings']
+    log_mel_spectrogram = results['log_mel_spectrogram']
     clip_predictions = np.mean(predictions, axis=0)
     top_n_indices = np.argsort(clip_predictions)[-top_n:]
     top_n_scores = clip_predictions[top_n_indices]

@@ -106,7 +113,9 @@ def make_tf2_export(weights_path, export_dir):
 
   # Make TF2 SavedModel export.
   log('Making TF2 SavedModel export ...')
-  tf.saved_model.save(yamnet, export_dir)
+  tf.saved_model.save(
+      yamnet, export_dir,
+      signatures={'serving_default': yamnet.__call__.get_concrete_function()})
   log('Done')
 
   # Check export with TF-Hub in TF2.

@@ -143,7 +152,9 @@ def make_tflite_export(weights_path, export_dir):
   log('Making TF-Lite SavedModel export ...')
   saved_model_dir = os.path.join(export_dir, 'saved_model')
   os.makedirs(saved_model_dir)
-  tf.saved_model.save(yamnet, saved_model_dir)
+  tf.saved_model.save(
+      yamnet, saved_model_dir,
+      signatures={'serving_default': yamnet.__call__.get_concrete_function()})
   log('Done')
 
   # Check that the export can be loaded and works.

@@ -154,7 +165,8 @@ def make_tflite_export(weights_path, export_dir):
 
   # Make a TF-Lite model from the SavedModel.
   log('Making TF-Lite model ...')
-  tflite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+  tflite_converter = tf.lite.TFLiteConverter.from_saved_model(
+      saved_model_dir, signature_keys=['serving_default'])
   tflite_model = tflite_converter.convert()
   tflite_model_path = os.path.join(export_dir, 'yamnet.tflite')
   with open(tflite_model_path, 'wb') as f:

@@ -164,19 +176,8 @@ def make_tflite_export(weights_path, export_dir):
   # Check the TF-Lite export.
   log('Checking TF-Lite model ...')
   interpreter = tf.lite.Interpreter(tflite_model_path)
-  audio_input_index = interpreter.get_input_details()[0]['index']
-  scores_output_index = interpreter.get_output_details()[0]['index']
-  embeddings_output_index = interpreter.get_output_details()[1]['index']
-  spectrogram_output_index = interpreter.get_output_details()[2]['index']
-  def run_model(waveform):
-    interpreter.resize_tensor_input(audio_input_index, [len(waveform)], strict=True)
-    interpreter.allocate_tensors()
-    interpreter.set_tensor(audio_input_index, waveform)
-    interpreter.invoke()
-    return (interpreter.get_tensor(scores_output_index),
-            interpreter.get_tensor(embeddings_output_index),
-            interpreter.get_tensor(spectrogram_output_index))
-  check_model(run_model, 'yamnet_class_map.csv', params)
+  runner = interpreter.get_signature_runner('serving_default')
+  check_model(runner, 'yamnet_class_map.csv', params)
   log('Done')
 
   return saved_model_dir
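
With the explicit 'serving_default' signature in place, downstream clients can call the exported TF-Lite model by name instead of juggling input/output tensor indices. Below is a minimal sketch, not part of this commit: the model path 'yamnet.tflite' and the silent test waveform are placeholders, and it assumes a TensorFlow build recent enough to provide tf.lite.Interpreter.get_signature_runner.

import numpy as np
import tensorflow as tf

# Load the exported model and look up the named signature added by this commit.
interpreter = tf.lite.Interpreter(model_path='yamnet.tflite')  # placeholder path
runner = interpreter.get_signature_runner('serving_default')

# YAMNet expects a mono float32 waveform at 16 kHz; three seconds of silence
# is used here purely as a placeholder input.
waveform = np.zeros(3 * 16000, dtype=np.float32)

# The keyword must match the signature's input name ('waveform'); the result
# is the dict of named outputs defined in YAMNet.__call__ above.
outputs = runner(waveform=waveform)
predictions = outputs['predictions']                  # per-frame class scores
embeddings = outputs['embeddings']                    # per-frame embeddings
log_mel_spectrogram = outputs['log_mel_spectrogram']  # per-frame spectrogram
print(predictions.shape, embeddings.shape, log_mel_spectrogram.shape)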
