@@ -44,28 +44,35 @@ def log(msg):
44
44
45
45
46
46
class YAMNet (tf .Module ):
47
- "'' A TF2 Module wrapper around YAMNet." ""
47
+ """ A TF2 Module wrapper around YAMNet."""
48
48
def __init__ (self , weights_path , params ):
49
49
super ().__init__ ()
50
50
self ._yamnet = yamnet .yamnet_frames_model (params )
51
51
self ._yamnet .load_weights (weights_path )
52
52
self ._class_map_asset = tf .saved_model .Asset ('yamnet_class_map.csv' )
53
53
54
- @tf .function
54
+ @tf .function ( input_signature = [])
55
55
def class_map_path (self ):
56
56
return self ._class_map_asset .asset_path
57
57
58
- @tf .function (input_signature = ( tf .TensorSpec (shape = [None ], dtype = tf .float32 ),) )
58
+ @tf .function (input_signature = [ tf .TensorSpec (shape = [None ], dtype = tf .float32 )] )
59
59
def __call__ (self , waveform ):
60
- return self ._yamnet (waveform )
60
+ predictions , embeddings , log_mel_spectrogram = self ._yamnet (waveform )
61
+
62
+ return {'predictions' : predictions ,
63
+ 'embeddings' : embeddings ,
64
+ 'log_mel_spectrogram' : log_mel_spectrogram }
61
65
62
66
63
67
def check_model (model_fn , class_map_path , params ):
64
68
yamnet_classes = yamnet .class_names (class_map_path )
65
69
66
70
"""Applies yamnet_test's sanity checks to an instance of YAMNet."""
67
71
def clip_test (waveform , expected_class_name , top_n = 10 ):
68
- predictions , embeddings , log_mel_spectrogram = model_fn (waveform )
72
+ results = model_fn (waveform = waveform )
73
+ predictions = results ['predictions' ]
74
+ embeddings = results ['embeddings' ]
75
+ log_mel_spectrogram = results ['log_mel_spectrogram' ]
69
76
clip_predictions = np .mean (predictions , axis = 0 )
70
77
top_n_indices = np .argsort (clip_predictions )[- top_n :]
71
78
top_n_scores = clip_predictions [top_n_indices ]
@@ -106,7 +113,9 @@ def make_tf2_export(weights_path, export_dir):
106
113
107
114
# Make TF2 SavedModel export.
108
115
log ('Making TF2 SavedModel export ...' )
109
- tf .saved_model .save (yamnet , export_dir )
116
+ tf .saved_model .save (
117
+ yamnet , export_dir ,
118
+ signatures = {'serving_default' : yamnet .__call__ .get_concrete_function ()})
110
119
log ('Done' )
111
120
112
121
# Check export with TF-Hub in TF2.
@@ -143,7 +152,9 @@ def make_tflite_export(weights_path, export_dir):
143
152
log ('Making TF-Lite SavedModel export ...' )
144
153
saved_model_dir = os .path .join (export_dir , 'saved_model' )
145
154
os .makedirs (saved_model_dir )
146
- tf .saved_model .save (yamnet , saved_model_dir )
155
+ tf .saved_model .save (
156
+ yamnet , saved_model_dir ,
157
+ signatures = {'serving_default' : yamnet .__call__ .get_concrete_function ()})
147
158
log ('Done' )
148
159
149
160
# Check that the export can be loaded and works.
@@ -154,7 +165,8 @@ def make_tflite_export(weights_path, export_dir):
154
165
155
166
# Make a TF-Lite model from the SavedModel.
156
167
log ('Making TF-Lite model ...' )
157
- tflite_converter = tf .lite .TFLiteConverter .from_saved_model (saved_model_dir )
168
+ tflite_converter = tf .lite .TFLiteConverter .from_saved_model (
169
+ saved_model_dir , signature_keys = ['serving_default' ])
158
170
tflite_model = tflite_converter .convert ()
159
171
tflite_model_path = os .path .join (export_dir , 'yamnet.tflite' )
160
172
with open (tflite_model_path , 'wb' ) as f :
@@ -164,19 +176,8 @@ def make_tflite_export(weights_path, export_dir):
164
176
# Check the TF-Lite export.
165
177
log ('Checking TF-Lite model ...' )
166
178
interpreter = tf .lite .Interpreter (tflite_model_path )
167
- audio_input_index = interpreter .get_input_details ()[0 ]['index' ]
168
- scores_output_index = interpreter .get_output_details ()[0 ]['index' ]
169
- embeddings_output_index = interpreter .get_output_details ()[1 ]['index' ]
170
- spectrogram_output_index = interpreter .get_output_details ()[2 ]['index' ]
171
- def run_model (waveform ):
172
- interpreter .resize_tensor_input (audio_input_index , [len (waveform )], strict = True )
173
- interpreter .allocate_tensors ()
174
- interpreter .set_tensor (audio_input_index , waveform )
175
- interpreter .invoke ()
176
- return (interpreter .get_tensor (scores_output_index ),
177
- interpreter .get_tensor (embeddings_output_index ),
178
- interpreter .get_tensor (spectrogram_output_index ))
179
- check_model (run_model , 'yamnet_class_map.csv' , params )
179
+ runner = interpreter .get_signature_runner ('serving_default' )
180
+ check_model (runner , 'yamnet_class_map.csv' , params )
180
181
log ('Done' )
181
182
182
183
return saved_model_dir
0 commit comments