Skip to content

Commit 3311242

Browse files
authored
1. update to tf2.x for deep_speech (#8696)
Update to TF 2 for deep_speech
1 parent ccf7da9 commit 3311242

File tree

4 files changed

+54
-78
lines changed

4 files changed

+54
-78
lines changed

Diff for: research/deep_speech/data/dataset.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -71,8 +71,8 @@ def __init__(self, audio_config, data_path, vocab_file_path, sortagrad):
7171
"""
7272

7373
self.audio_config = audio_config
74-
assert tf.gfile.Exists(data_path)
75-
assert tf.gfile.Exists(vocab_file_path)
74+
assert tf.io.gfile.exists(data_path)
75+
assert tf.io.gfile.exists(vocab_file_path)
7676
self.data_path = data_path
7777
self.vocab_file_path = vocab_file_path
7878
self.sortagrad = sortagrad
@@ -125,8 +125,8 @@ def _preprocess_data(file_path):
125125
A list of tuples (wav_filename, wav_filesize, transcript) sorted by
126126
file_size.
127127
"""
128-
tf.logging.info("Loading data set {}".format(file_path))
129-
with tf.gfile.Open(file_path, "r") as f:
128+
tf.compat.v1.logging.info("Loading data set {}".format(file_path))
129+
with tf.io.gfile.GFile(file_path, "r") as f:
130130
lines = f.read().splitlines()
131131
# Skip the csv header in lines[0].
132132
lines = lines[1:]

Diff for: research/deep_speech/data/download.py

+16 -16
Original file line number | Diff line number | Diff line change
@@ -59,13 +59,13 @@ def download_and_extract(directory, url):
5959
url: the url to download the data file.
6060
"""
6161

62-
if not tf.gfile.Exists(directory):
63-
tf.gfile.MakeDirs(directory)
62+
if not tf.io.gfile.exists(directory):
63+
tf.io.gfile.makedirs(directory)
6464

6565
_, tar_filepath = tempfile.mkstemp(suffix=".tar.gz")
6666

6767
try:
68-
tf.logging.info("Downloading %s to %s" % (url, tar_filepath))
68+
tf.compat.v1.logging.info("Downloading %s to %s" % (url, tar_filepath))
6969

7070
def _progress(count, block_size, total_size):
7171
sys.stdout.write("\r>> Downloading {} {:.1f}%".format(
@@ -75,12 +75,12 @@ def _progress(count, block_size, total_size):
7575
urllib.request.urlretrieve(url, tar_filepath, _progress)
7676
print()
7777
statinfo = os.stat(tar_filepath)
78-
tf.logging.info(
78+
tf.compat.v1.logging.info(
7979
"Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
8080
with tarfile.open(tar_filepath, "r") as tar:
8181
tar.extractall(directory)
8282
finally:
83-
tf.gfile.Remove(tar_filepath)
83+
tf.io.gfile.remove(tar_filepath)
8484

8585

8686
def convert_audio_and_split_transcript(input_dir, source_name, target_name,
@@ -112,18 +112,18 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
112112
output_file: the name of the newly generated csv file. e.g. test-clean.csv
113113
"""
114114

115-
tf.logging.info("Preprocessing audio and transcript for %s" % source_name)
115+
tf.compat.v1.logging.info("Preprocessing audio and transcript for %s" % source_name)
116116
source_dir = os.path.join(input_dir, source_name)
117117
target_dir = os.path.join(input_dir, target_name)
118118

119-
if not tf.gfile.Exists(target_dir):
120-
tf.gfile.MakeDirs(target_dir)
119+
if not tf.io.gfile.exists(target_dir):
120+
tf.io.gfile.makedirs(target_dir)
121121

122122
files = []
123123
tfm = Transformer()
124124
# Convert all FLAC file into WAV format. At the same time, generate the csv
125125
# file.
126-
for root, _, filenames in tf.gfile.Walk(source_dir):
126+
for root, _, filenames in tf.io.gfile.walk(source_dir):
127127
for filename in fnmatch.filter(filenames, "*.trans.txt"):
128128
trans_file = os.path.join(root, filename)
129129
with codecs.open(trans_file, "r", "utf-8") as fin:
@@ -137,7 +137,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
137137
# Convert FLAC to WAV.
138138
flac_file = os.path.join(root, seqid + ".flac")
139139
wav_file = os.path.join(target_dir, seqid + ".wav")
140-
if not tf.gfile.Exists(wav_file):
140+
if not tf.io.gfile.exists(wav_file):
141141
tfm.build(flac_file, wav_file)
142142
wav_filesize = os.path.getsize(wav_file)
143143

@@ -149,7 +149,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
149149
df = pandas.DataFrame(
150150
data=files, columns=["wav_filename", "wav_filesize", "transcript"])
151151
df.to_csv(csv_file_path, index=False, sep="\t")
152-
tf.logging.info("Successfully generated csv file {}".format(csv_file_path))
152+
tf.compat.v1.logging.info("Successfully generated csv file {}".format(csv_file_path))
153153

154154

155155
def download_and_process_datasets(directory, datasets):
@@ -160,10 +160,10 @@ def download_and_process_datasets(directory, datasets):
160160
datasets: list of dataset names that will be downloaded and processed.
161161
"""
162162

163-
tf.logging.info("Preparing LibriSpeech dataset: {}".format(
163+
tf.compat.v1.logging.info("Preparing LibriSpeech dataset: {}".format(
164164
",".join(datasets)))
165165
for dataset in datasets:
166-
tf.logging.info("Preparing dataset %s", dataset)
166+
tf.compat.v1.logging.info("Preparing dataset %s", dataset)
167167
dataset_dir = os.path.join(directory, dataset)
168168
download_and_extract(dataset_dir, LIBRI_SPEECH_URLS[dataset])
169169
convert_audio_and_split_transcript(
@@ -185,8 +185,8 @@ def define_data_download_flags():
185185

186186

187187
def main(_):
188-
if not tf.gfile.Exists(FLAGS.data_dir):
189-
tf.gfile.MakeDirs(FLAGS.data_dir)
188+
if not tf.io.gfile.exists(FLAGS.data_dir):
189+
tf.io.gfile.makedirs(FLAGS.data_dir)
190190

191191
if FLAGS.train_only:
192192
download_and_process_datasets(
@@ -202,7 +202,7 @@ def main(_):
202202

203203

204204
if __name__ == "__main__":
205-
tf.logging.set_verbosity(tf.logging.INFO)
205+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
206206
define_data_download_flags()
207207
FLAGS = absl_flags.FLAGS
208208
absl_app.run(main)

Diff for: research/deep_speech/deep_speech.py

+19 -35
Original file line number | Diff line number | Diff line change
@@ -61,25 +61,10 @@ def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
6161
Returns:
6262
the ctc_input_length after convolution layer.
6363
"""
64-
ctc_input_length = tf.to_float(tf.multiply(
65-
input_length, ctc_time_steps))
66-
return tf.to_int32(tf.floordiv(
67-
ctc_input_length, tf.to_float(max_time_steps)))
68-
69-
70-
def ctc_loss(label_length, ctc_input_length, labels, logits):
71-
"""Computes the ctc loss for the current batch of predictions."""
72-
label_length = tf.to_int32(tf.squeeze(label_length))
73-
ctc_input_length = tf.to_int32(tf.squeeze(ctc_input_length))
74-
sparse_labels = tf.to_int32(
75-
tf.keras.backend.ctc_label_dense_to_sparse(labels, label_length))
76-
y_pred = tf.log(tf.transpose(
77-
logits, perm=[1, 0, 2]) + tf.keras.backend.epsilon())
78-
79-
return tf.expand_dims(
80-
tf.nn.ctc_loss(labels=sparse_labels, inputs=y_pred,
81-
sequence_length=ctc_input_length),
82-
axis=1)
64+
ctc_input_length = tf.cast(tf.multiply(
65+
input_length, ctc_time_steps), dtype=tf.float32)
66+
return tf.cast(tf.math.floordiv(
67+
ctc_input_length, tf.cast(max_time_steps, dtype=tf.float32)), dtype=tf.int32)
8368

8469

8570
def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
@@ -123,11 +108,11 @@ def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
123108
total_cer /= num_of_examples
124109
total_wer /= num_of_examples
125110

126-
global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
111+
global_step = estimator.get_variable_value(tf.compat.v1.GraphKeys.GLOBAL_STEP)
127112
eval_results = {
128113
_WER_KEY: total_wer,
129114
_CER_KEY: total_cer,
130-
tf.GraphKeys.GLOBAL_STEP: global_step,
115+
tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step,
131116
}
132117

133118
return eval_results
@@ -163,7 +148,7 @@ def model_fn(features, labels, mode, params):
163148
logits = model(features, training=False)
164149
predictions = {
165150
"classes": tf.argmax(logits, axis=2),
166-
"probabilities": tf.nn.softmax(logits),
151+
"probabilities": logits,
167152
"logits": logits
168153
}
169154
return tf.estimator.EstimatorSpec(
@@ -172,17 +157,16 @@ def model_fn(features, labels, mode, params):
172157

173158
# In training mode.
174159
logits = model(features, training=True)
175-
probs = tf.nn.softmax(logits)
176160
ctc_input_length = compute_length_after_conv(
177-
tf.shape(features)[1], tf.shape(probs)[1], input_length)
161+
tf.shape(features)[1], tf.shape(logits)[1], input_length)
178162
# Compute CTC loss
179-
loss = tf.reduce_mean(ctc_loss(
180-
label_length, ctc_input_length, labels, probs))
163+
loss = tf.reduce_mean(tf.keras.backend.ctc_batch_cost(
164+
labels, logits, ctc_input_length, label_length))
181165

182-
optimizer = tf.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
183-
global_step = tf.train.get_or_create_global_step()
166+
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
167+
global_step = tf.compat.v1.train.get_or_create_global_step()
184168
minimize_op = optimizer.minimize(loss, global_step=global_step)
185-
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
169+
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
186170
# Create the train_op that groups both minimize_ops and update_ops
187171
train_op = tf.group(minimize_op, update_ops)
188172

@@ -239,9 +223,9 @@ def per_device_batch_size(batch_size, num_gpus):
239223

240224
def run_deep_speech(_):
241225
"""Run deep speech training and eval loop."""
242-
tf.set_random_seed(flags_obj.seed)
226+
tf.compat.v1.set_random_seed(flags_obj.seed)
243227
# Data preprocessing
244-
tf.logging.info("Data preprocessing...")
228+
tf.compat.v1.logging.info("Data preprocessing...")
245229
train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
246230
eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)
247231

@@ -287,7 +271,7 @@ def input_fn_eval():
287271
total_training_cycle = (flags_obj.train_epochs //
288272
flags_obj.epochs_between_evals)
289273
for cycle_index in range(total_training_cycle):
290-
tf.logging.info("Starting a training cycle: %d/%d",
274+
tf.compat.v1.logging.info("Starting a training cycle: %d/%d",
291275
cycle_index + 1, total_training_cycle)
292276

293277
# Perform batch_wise dataset shuffling
@@ -298,15 +282,15 @@ def input_fn_eval():
298282
estimator.train(input_fn=input_fn_train)
299283

300284
# Evaluation
301-
tf.logging.info("Starting to evaluate...")
285+
tf.compat.v1.logging.info("Starting to evaluate...")
302286

303287
eval_results = evaluate_model(
304288
estimator, eval_speech_dataset.speech_labels,
305289
eval_speech_dataset.entries, input_fn_eval)
306290

307291
# Log the WER and CER results.
308292
benchmark_logger.log_evaluation_result(eval_results)
309-
tf.logging.info(
293+
tf.compat.v1.logging.info(
310294
"Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
311295
cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))
312296

@@ -425,7 +409,7 @@ def main(_):
425409

426410

427411
if __name__ == "__main__":
428-
tf.logging.set_verbosity(tf.logging.INFO)
412+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
429413
define_deep_speech_flags()
430414
flags_obj = flags.FLAGS
431415
absl_app.run(main)

Diff for: research/deep_speech/deep_speech_model.py

+15 -23
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,9 @@
2222

2323
# Supported rnn cells.
2424
SUPPORTED_RNNS = {
25-
"lstm": tf.contrib.rnn.BasicLSTMCell,
26-
"rnn": tf.contrib.rnn.RNNCell,
27-
"gru": tf.contrib.rnn.GRUCell,
25+
"lstm": tf.keras.layers.LSTMCell,
26+
"rnn": tf.keras.layers.SimpleRNNCell,
27+
"gru": tf.keras.layers.GRUCell,
2828
}
2929

3030
# Parameters for batch normalization.
@@ -53,9 +53,8 @@ def batch_norm(inputs, training):
5353
Returns:
5454
tensor output from batch norm layer.
5555
"""
56-
return tf.layers.batch_normalization(
57-
inputs=inputs, momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,
58-
fused=True, training=training)
56+
return tf.keras.layers.BatchNormalization(
57+
momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON)(inputs, training=training)
5958

6059

6160
def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
@@ -81,10 +80,10 @@ def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
8180
inputs = tf.pad(
8281
inputs,
8382
[[0, 0], [padding[0], padding[0]], [padding[1], padding[1]], [0, 0]])
84-
inputs = tf.layers.conv2d(
85-
inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
83+
inputs = tf.keras.layers.Conv2D(
84+
filters=filters, kernel_size=kernel_size, strides=strides,
8685
padding="valid", use_bias=False, activation=tf.nn.relu6,
87-
name="cnn_{}".format(layer_id))
86+
name="cnn_{}".format(layer_id))(inputs)
8887
return batch_norm(inputs, training)
8988

9089

@@ -109,24 +108,16 @@ def _rnn_layer(inputs, rnn_cell, rnn_hidden_size, layer_id, is_batch_norm,
109108
if is_batch_norm:
110109
inputs = batch_norm(inputs, training)
111110

112-
# Construct forward/backward RNN cells.
113-
fw_cell = rnn_cell(num_units=rnn_hidden_size,
114-
name="rnn_fw_{}".format(layer_id))
115-
bw_cell = rnn_cell(num_units=rnn_hidden_size,
116-
name="rnn_bw_{}".format(layer_id))
117-
118111
if is_bidirectional:
119-
outputs, _ = tf.nn.bidirectional_dynamic_rnn(
120-
cell_fw=fw_cell, cell_bw=bw_cell, inputs=inputs, dtype=tf.float32,
121-
swap_memory=True)
122-
rnn_outputs = tf.concat(outputs, -1)
112+
rnn_outputs = tf.keras.layers.Bidirectional(
113+
tf.keras.layers.RNN(rnn_cell(rnn_hidden_size),
114+
return_sequences=True))(inputs)
123115
else:
124-
rnn_outputs = tf.nn.dynamic_rnn(
125-
fw_cell, inputs, dtype=tf.float32, swap_memory=True)
116+
rnn_outputs = tf.keras.layers.RNN(
117+
rnn_cell(rnn_hidden_size), return_sequences=True)(inputs)
126118

127119
return rnn_outputs
128120

129-
130121
class DeepSpeech2(object):
131122
"""Define DeepSpeech2 model."""
132123

@@ -179,7 +170,8 @@ def __call__(self, inputs, training):
179170

180171
# FC layer with batch norm.
181172
inputs = batch_norm(inputs, training)
182-
logits = tf.layers.dense(inputs, self.num_classes, use_bias=self.use_bias)
173+
logits = tf.keras.layers.Dense(
174+
self.num_classes, use_bias=self.use_bias, activation="softmax")(inputs)
183175

184176
return logits
185177

0 commit comments

Comments
 (0)