Skip to content

Commit 5e90c6c

Browse files
wangtz authored and MarkDaoust committed
Model Maker: Add audio classification demo.
PiperOrigin-RevId: 352510173
1 parent b7728ba commit 5e90c6c

File tree

3 files changed

+184
-1
lines changed

3 files changed

+184
-1
lines changed

tensorflow_examples/lite/model_maker/core/test_util.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ def test_srcdir():
4545

4646
def get_test_data_path(file_or_dirname):
4747
"""Return full test data path."""
48-
for directory, subdirs, files in tf.io.gfile.walk(test_srcdir()):
48+
for (directory, subdirs, files) in tf.io.gfile.walk(test_srcdir()):
4949
for f in subdirs + files:
5050
if f.endswith(file_or_dirname):
5151
return os.path.join(directory, f)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Audio classification demo code of Model Maker for TFLite.
15+
16+
Example usage:
17+
python audio_classification_demo.py --export_dir=/tmp
18+
19+
Sample output:
20+
Downloading data from
21+
https://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip
22+
182083584/182082353 [==============================] - 4s 0us/step
23+
182091776/182082353 [==============================] - 4s 0us/step
24+
Dataset has been downloaded to
25+
/usr/local/google/home/wangtz/.keras/datasets/mini_speech_commands
26+
Processing audio files:
27+
8000/8000 [==============================] - 354s 44ms/file
28+
Cached 7178 audio samples.
29+
Training the model
30+
5742/5742 [==============================] - 29s 5ms/step - loss: 3.2289 - acc:
31+
0.8029 - val_loss: 0.6229 - val_acc: 0.9638
32+
Evaluating the model
33+
15/15 [==============================] - 2s 12ms/step - loss: 1.3569 - acc:
34+
0.9270
35+
Test accuracy: 0.927039
36+
"""
37+
38+
from __future__ import absolute_import
39+
from __future__ import division
40+
from __future__ import print_function
41+
42+
from absl import app
43+
from absl import flags
44+
from absl import logging
45+
46+
import tensorflow as tf
47+
from tensorflow_examples.lite.model_maker.core.data_util import audio_dataloader
48+
from tensorflow_examples.lite.model_maker.core.task import audio_classifier
49+
from tensorflow_examples.lite.model_maker.core.task import model_spec
50+
51+
FLAGS = flags.FLAGS


def define_flags():
  """Registers the command-line flags understood by this demo.

  Two string flags are declared: `export_dir` (no default, marked as
  required) and `spec` (defaults to 'audio_browser_fft').
  """
  flags.DEFINE_string(
      name='export_dir',
      default=None,
      help='The directory to save exported files.')
  flags.DEFINE_string(
      name='spec',
      default='audio_browser_fft',
      help='Name of the model spec to use.')
  # `export_dir` has no usable default, so make absl reject runs without it.
  flags.mark_flag_as_required('export_dir')
60+
61+
62+
def download_dataset(**kwargs):
  """Fetches the mini speech commands archive and returns its folder path.

  Args:
    **kwargs: Extra keyword arguments forwarded unchanged to
      `tf.keras.utils.get_file`.

  Returns:
    The path returned by `get_file` with everything from its last '.'
    onwards removed.
  """
  tf.compat.v1.logging.info('Downloading mini speech command dataset.')
  archive_url = 'https://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip'
  # With default arguments: ${HOME}/.keras/datasets/mini_speech_commands.zip
  archive_path = tf.keras.utils.get_file(
      fname='mini_speech_commands.zip',
      origin=archive_url,
      extract=True,
      **kwargs)
  # Drop the trailing extension, e.g.
  # ${HOME}/.keras/datasets/mini_speech_commands
  folder_path, *_ = archive_path.rsplit('.', 1)
  print(f'Dataset has been downloaded to {folder_path}')
  return folder_path
75+
76+
77+
def run(data_dir, export_dir, spec='audio_browser_fft', **kwargs):
  """Trains, evaluates and exports an audio classifier.

  Args:
    data_dir: Folder passed to `audio_dataloader.DataLoader.from_folder`.
    export_dir: Destination passed to `model.export`.
    spec: Model spec name, resolved through `model_spec.get`.
    **kwargs: Extra keyword arguments forwarded to `audio_classifier.create`.
  """
  resolved_spec = model_spec.get(spec)
  dataset = audio_dataloader.DataLoader.from_folder(resolved_spec, data_dir)

  # Two-stage split: first at 0.8, then the remainder at 0.5
  # (presumably yielding 80/10/10 train/validation/test -- confirm against
  # DataLoader.split).
  train_data, held_out = dataset.split(0.8)
  validation_data, test_data = held_out.split(0.5)

  print('Training the model')
  model = audio_classifier.create(
      train_data, resolved_spec, validation_data, **kwargs)

  print('Evaluating the model')
  # Only the second value returned by evaluate() is reported.
  _, accuracy = model.evaluate(test_data)
  print('Test accuracy: %f' % accuracy)

  model.export(export_dir)
93+
94+
95+
def main(_):
  """Entry point for `app.run`: downloads the dataset and runs the demo.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  logging.set_verbosity(logging.INFO)
  data_dir = download_dataset()
  # Forward --spec explicitly. The flag is registered in define_flags() but
  # was previously never read, so the command-line value was silently ignored
  # and run() always fell back to its own default.
  run(data_dir, FLAGS.export_dir, spec=FLAGS.spec)
99+
100+
101+
if __name__ == '__main__':
  # Flags must be registered before app.run() parses the command line.
  define_flags()
  app.run(main)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License');
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an 'AS IS' BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import os
20+
import tempfile
21+
import unittest
22+
23+
import tensorflow as tf
24+
25+
from tensorflow_examples.lite.model_maker.core import test_util
26+
from tensorflow_examples.lite.model_maker.core.data_util.audio_dataloader import DataLoader
27+
from tensorflow_examples.lite.model_maker.demo import audio_classification_demo
28+
29+
30+
# Keep a handle on the real loader so the patched version can delegate to it.
from_folder_fn = DataLoader.from_folder


def patch_data_loader():
  """Patch to train partial dataset rather than all of them.

  Returns:
    A `mock.patch.object` patcher (usable as a context manager) that replaces
    `DataLoader.from_folder` with a wrapper which calls the real loader and
    then trims the result to at most 10 samples.
  """
  # Import the submodule explicitly: a bare `import unittest` does not load
  # `unittest.mock`, so `unittest.mock.patch` only works if some other module
  # happened to import it first.
  from unittest import mock  # pylint: disable=g-import-not-at-top

  max_samples = 10

  def side_effect(*args, **kwargs):
    tf.compat.v1.logging.info('Train on partial dataset')
    # This takes around 8 mins as it caches all files in the folder.
    # We should be able to address this issue once the dataset is lazily loaded.
    data_loader = from_folder_fn(*args, **kwargs)
    if len(data_loader) > max_samples:  # Trim dataset to at most max_samples.
      data_loader._size = max_samples
      # TODO(b/171449557): Change this once the dataset is lazily loaded.
      data_loader._dataset = data_loader._dataset.take(max_samples)
    return data_loader

  return mock.patch.object(
      DataLoader, 'from_folder', side_effect=side_effect)
49+
50+
51+
class AudioClassificationDemoTest(tf.test.TestCase):
  """End-to-end smoke test for the audio classification demo."""

  def test_audio_classification_demo(self):
    """Runs the demo on a trimmed dataset and checks the exported files."""
    with patch_data_loader(), tempfile.TemporaryDirectory() as temp_dir:
      # Use cached training data if exists.
      cache_dir = test_util.get_cache_dir(temp_dir, 'mini_speech_commands.zip')
      data_dir = audio_classification_demo.download_dataset(
          cache_dir=cache_dir,
          file_hash='4b8a67bae2973844e84fa7ac988d1a44')

      audio_classification_demo.run(
          data_dir, temp_dir, spec='audio_browser_fft', epochs=1, batch_size=1)

      # A non-empty TFLite model must have been exported.
      tflite_filename = os.path.join(temp_dir, 'model.tflite')
      self.assertTrue(tf.io.gfile.exists(tflite_filename))
      self.assertGreater(os.path.getsize(tflite_filename), 0)

      # No separate label file is expected in the export directory.
      label_filename = os.path.join(temp_dir, 'labels.txt')
      self.assertFalse(tf.io.gfile.exists(label_filename))
75+
76+
77+
if __name__ == '__main__':
  # Ask tensorflow_hub to load models in compressed format; this must be set
  # before the tests run.
  os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
  tf.test.main()

0 commit comments

Comments
 (0)