Skip to content

Commit 2dc6b91

Browse files
authored
Add deep speech model and run loop (#4632)
* Add deep speech model and run loop * fix lints and add init * Add dataset and address comments
1 parent fb3fba0 commit 2dc6b91

File tree

7 files changed

+850
-0
lines changed

7 files changed

+850
-0
lines changed

research/deep_speech/__init__.py

Whitespace-only changes.

research/deep_speech/data/__init__.py

Whitespace-only changes.

research/deep_speech/data/dataset.py

+222
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Generate tf.data.Dataset object for deep speech training/evaluation."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import numpy as np
21+
import scipy.io.wavfile as wavfile
22+
from six.moves import xrange # pylint: disable=redefined-builtin
23+
import tensorflow as tf
24+
25+
# pylint: disable=g-bad-import-order
26+
from data.featurizer import AudioFeaturizer
27+
from data.featurizer import TextFeaturizer
28+
29+
30+
class AudioConfig(object):
  """Container for the audio settings used during spectrogram extraction."""

  def __init__(self,
               sample_rate,
               frame_length,
               frame_step,
               fft_length=None,
               normalize=False,
               spect_type="linear"):
    """Store the spectrogram extraction settings on the instance.

    Args:
      sample_rate: an integer, the sample rate of the input waveform.
      frame_length: an integer, the length of one spectrogram frame, in ms.
      frame_step: an integer, the stride between frames, in ms.
      fft_length: an integer, the number of fft bins.
      normalize: a boolean, whether normalization is applied to the audio
        tensor.
      spect_type: a string naming the type of spectrogram to extract.
    """
    # Framing parameters.
    self.sample_rate, self.frame_length, self.frame_step = (
        sample_rate, frame_length, frame_step)
    # Spectrogram parameters.
    self.fft_length = fft_length
    self.spect_type = spect_type
    # Post-processing switch.
    self.normalize = normalize
57+
58+
59+
class DatasetConfig(object):
  """Config class for generating the DeepSpeechDataset."""

  def __init__(self, audio_config, data_path, vocab_file_path):
    """Initialize the configs for deep speech dataset.

    Args:
      audio_config: AudioConfig object specifying the audio-related configs.
      data_path: a string denoting the full path of a manifest file.
      vocab_file_path: a string specifying the vocabulary file path.

    Raises:
      RuntimeError: if `data_path` or `vocab_file_path` does not exist.
    """
    # Validate with explicit checks rather than `assert`: assertions are
    # stripped when Python runs with -O, and the documented contract of this
    # constructor is a RuntimeError for a missing file.
    for path in (data_path, vocab_file_path):
      if not tf.gfile.Exists(path):
        raise RuntimeError("File path does not exist: %s" % path)

    self.audio_config = audio_config
    self.data_path = data_path
    self.vocab_file_path = vocab_file_path
79+
80+
81+
class DeepSpeechDataset(object):
  """Dataset class for training/evaluation of DeepSpeech model."""

  def __init__(self, dataset_config):
    """Initialize the class.

    Each dataset file contains three columns: "wav_filename", "wav_filesize",
    and "transcript". This function parses the file and stores each example
    by the increasing order of audio length (indicated by wav_filesize).

    Args:
      dataset_config: DatasetConfig object.
    """
    self.config = dataset_config
    audio_config = self.config.audio_config
    # Instantiate audio feature extractor.
    self.audio_featurizer = AudioFeaturizer(
        sample_rate=audio_config.sample_rate,
        frame_length=audio_config.frame_length,
        frame_step=audio_config.frame_step,
        fft_length=audio_config.fft_length,
        spect_type=audio_config.spect_type)
    # Instantiate text feature extractor.
    self.text_featurizer = TextFeaturizer(
        vocab_file=self.config.vocab_file_path)

    self.speech_labels = self.text_featurizer.speech_labels
    self.features, self.labels = self._preprocess_data(self.config.data_path)
    # Second dimension of the first feature array; None for an empty dataset.
    self.num_feature_bins = (
        self.features[0].shape[1] if self.features else None)

  def _preprocess_data(self, file_path):
    """Generate a list of waveform, transcript pair.

    Note that the waveforms are ordered in increasing length, so that audio
    samples in a mini-batch have similar length.

    Args:
      file_path: a string specifying the path of the tab-separated manifest
        file for a data set. The first line is treated as a header.

    Returns:
      features and labels lists processed from the audio/text input, in the
      same (sorted) order.
    """
    with tf.gfile.Open(file_path, "r") as f:
      lines = f.read().splitlines()
    lines = [line.split("\t") for line in lines]
    # Skip the header.
    lines = lines[1:]
    # Sort input data by the file size column, used as a proxy for the length
    # of the waveform.
    lines.sort(key=lambda item: int(item[1]))
    features = [self._preprocess_audio(line[0]) for line in lines]
    labels = [self._preprocess_transcript(line[2]) for line in lines]
    return features, labels

  def _normalize_audio_tensor(self, audio_tensor):
    """Perform mean and variance normalization on the spectrogram tensor.

    Args:
      audio_tensor: a tensor for the spectrogram feature.

    Returns:
      a tensor for the normalized spectrogram.
    """
    # Moments are taken along axis 0; the epsilon guards against division by
    # zero when the variance is zero.
    mean, var = tf.nn.moments(audio_tensor, axes=[0])
    normalized = (audio_tensor - mean) / (tf.sqrt(var) + 1e-6)
    return normalized

  def _preprocess_audio(self, audio_file_path):
    """Load an audio file and compute its spectrogram feature.

    Args:
      audio_file_path: a string, path of the wav file to read.

    Returns:
      The evaluated feature (the result of `Session.run`), i.e. a numpy array
      rather than a tensor.
    """
    tf.logging.info(
        "Extracting spectrogram feature for {}".format(audio_file_path))
    sample_rate, data = wavfile.read(audio_file_path)
    assert sample_rate == self.config.audio_config.sample_rate
    if data.dtype not in [np.float32, np.float64]:
      # Scale integer samples by the maximum value of their dtype.
      data = data.astype(np.float32) / np.iinfo(data.dtype).max

    # Build the feature ops in a throwaway graph so that the waveform constant
    # and the STFT ops of every file do not pile up in the global default
    # graph, and close the session as soon as the value is computed so that
    # its resources are released instead of leaking one session per file.
    with tf.Graph().as_default():
      feature = self.audio_featurizer.featurize(data)
      if self.config.audio_config.normalize:
        feature = self._normalize_audio_tensor(feature)
      with tf.Session() as sess:
        return sess.run(feature)

  def _preprocess_transcript(self, transcript):
    """Convert a transcript string into features via the text featurizer."""
    return self.text_featurizer.featurize(transcript)
164+
165+
166+
def input_fn(batch_size, deep_speech_dataset, repeat=1):
  """Build the input pipeline used for model training and evaluation.

  Args:
    batch_size: an integer denoting the size of a batch.
    deep_speech_dataset: DeepSpeechDataset object.
    repeat: an integer for how many times to repeat the dataset.

  Returns:
    a tf.data.Dataset object for model to consume.
  """
  all_features = deep_speech_dataset.features
  all_labels = deep_speech_dataset.labels
  num_feature_bins = deep_speech_dataset.num_feature_bins

  def _gen_data():
    """Yield one example dict per stored feature array."""
    for idx, spectrogram in enumerate(all_features):
      label = all_labels[idx]
      yield {
          # Append a trailing axis of size one to the feature array.
          "features": np.expand_dims(spectrogram, axis=2),
          "labels": label,
          "input_length": [spectrogram.shape[0]],
          "label_length": [len(label)]
      }

  def _example_shapes():
    """Return the per-example shapes; leading dimensions are variable."""
    return {
        "features": tf.TensorShape([None, num_feature_bins, 1]),
        "labels": tf.TensorShape([None]),
        "input_length": tf.TensorShape([1]),
        "label_length": tf.TensorShape([1])
    }

  example_types = {
      "features": tf.float32,
      "labels": tf.int32,
      "input_length": tf.int32,
      "label_length": tf.int32
  }
  dataset = tf.data.Dataset.from_generator(
      _gen_data,
      output_types=example_types,
      output_shapes=_example_shapes())

  # Repeat, pad every example in a batch to the batch's max length, and
  # prefetch one batch to improve the speed of the input pipeline.
  return dataset.repeat(repeat).padded_batch(
      batch_size=batch_size,
      padded_shapes=_example_shapes()).prefetch(1)
+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utility class for extracting features from the text and audio input."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import codecs
21+
import functools
22+
import numpy as np
23+
import tensorflow as tf
24+
25+
26+
class AudioFeaturizer(object):
  """Spectrogram feature extractor for audio input."""

  def __init__(self,
               sample_rate=16000,
               frame_length=25,
               frame_step=10,
               fft_length=None,
               window_fn=functools.partial(
                   tf.contrib.signal.hann_window, periodic=True),
               spect_type="linear"):
    """Set up the featurizer from the given configs.

    Args:
      sample_rate: an integer specifying the sample rate of the input waveform.
      frame_length: an integer for the length of a spectrogram frame, in ms.
      frame_step: an integer for the frame stride, in ms.
      fft_length: an integer for the number of fft bins. When not given, the
        smallest power of two not below the frame length (in samples) is used.
      window_fn: windowing function.
      spect_type: a string for the type of spectrogram to be extracted.
        Currently only support 'linear', otherwise will raise a value error.

    Raises:
      ValueError: In case of invalid arguments for `spect_type`.
    """
    if spect_type != "linear":
      raise ValueError("Unsupported spectrogram type: %s" % spect_type)

    self.window_fn = window_fn
    # Convert both durations from milliseconds to a number of samples.
    self.frame_length, self.frame_step = (
        int(sample_rate * duration_ms / 1e3)
        for duration_ms in (frame_length, frame_step))

    if fft_length:
      self.fft_length = fft_length
    else:
      next_pow2_exponent = np.ceil(np.log2(self.frame_length))
      self.fft_length = int(2**(next_pow2_exponent))

  def featurize(self, waveform):
    """Return the spectrogram feature tensor of the waveform."""
    return self._compute_linear_spectrogram(waveform)

  def _compute_linear_spectrogram(self, waveform):
    """Compute the linear-scale, magnitude spectrogram of the waveform.

    Args:
      waveform: a float32 audio tensor.
    Returns:
      a float 32 tensor with shape [len, num_bins]
    """
    # The Short-time Fourier Transform is complex valued, with shape
    # [?, fft_unique_bins] where fft_unique_bins = fft_length // 2 + 1.
    # Its magnitude is the energy spectrogram.
    return tf.abs(
        tf.contrib.signal.stft(
            waveform,
            frame_length=self.frame_length,
            frame_step=self.frame_step,
            fft_length=self.fft_length,
            window_fn=self.window_fn,
            pad_end=True))

  def _compute_mel_filterbank_features(self, waveform):
    """Placeholder for mel filterbank features; always raises."""
    raise NotImplementedError("MFCC feature extraction not supported yet.")
91+
92+
93+
class TextFeaturizer(object):
  """Extract text feature based on char-level granularity.

  By looking up the vocabulary table, each input string (one line of transcript)
  will be converted to a sequence of integer indexes.
  """

  def __init__(self, vocab_file):
    """Build the vocabulary tables from a vocabulary file.

    Args:
      vocab_file: a string, path of a utf-8 encoded text file holding one
        token per line. Lines starting with '#' are comments and are skipped.
        Every other line, including one made only of whitespace, is a token
        whose index is its position among the non-comment lines.
    """
    with codecs.open(vocab_file, "r", "utf-8") as fin:
      # splitlines() removes only the line terminator, whatever it is ("\n" or
      # "\r\n"), and leaves a final line without a terminator intact. Blindly
      # dropping the last character would corrupt such a final token and keep
      # a stray "\r" on tokens read from CRLF files.
      lines = fin.read().splitlines()

    # Comment lines do not consume an index.
    tokens = [line for line in lines if not line.startswith("#")]

    self.token_to_idx = {token: idx for idx, token in enumerate(tokens)}
    self.idx_to_token = dict(enumerate(tokens))
    # All tokens concatenated in index order.
    self.speech_labels = "".join(tokens)

  def featurize(self, text):
    """Convert string to a list of integers.

    Args:
      text: a string; it is stripped of surrounding whitespace and lowercased
        before being split into single characters.

    Returns:
      a list with the vocabulary index of every character.

    Raises:
      KeyError: if a character is not part of the vocabulary.
    """
    tokens = list(text.strip().lower())
    return [self.token_to_idx[token] for token in tokens]
+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# List of alphabets (utf-8 encoded). Note that '#' starts a comment line, which
2+
# will be ignored by the parser.
3+
# begin of vocabulary
4+
a
5+
b
6+
c
7+
d
8+
e
9+
f
10+
g
11+
h
12+
i
13+
j
14+
k
15+
l
16+
m
17+
n
18+
o
19+
p
20+
q
21+
r
22+
s
23+
t
24+
u
25+
v
26+
w
27+
x
28+
y
29+
z
30+
'
31+
32+
-
33+
# end of vocabulary

0 commit comments

Comments
 (0)