@@ -61,25 +61,10 @@ def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
6161 Returns:
6262 the ctc_input_length after convolution layer.
6363 """
64-  ctc_input_length = tf.to_float(tf.multiply(
65-      input_length, ctc_time_steps))
66-  return tf.to_int32(tf.floordiv(
67-      ctc_input_length, tf.to_float(max_time_steps)))
68-
69-
70-def ctc_loss(label_length, ctc_input_length, labels, logits):
71-  """Computes the ctc loss for the current batch of predictions."""
72-  label_length = tf.to_int32(tf.squeeze(label_length))
73-  ctc_input_length = tf.to_int32(tf.squeeze(ctc_input_length))
74-  sparse_labels = tf.to_int32(
75-      tf.keras.backend.ctc_label_dense_to_sparse(labels, label_length))
76-  y_pred = tf.log(tf.transpose(
77-      logits, perm=[1, 0, 2]) + tf.keras.backend.epsilon())
78-
79-  return tf.expand_dims(
80-      tf.nn.ctc_loss(labels=sparse_labels, inputs=y_pred,
81-                     sequence_length=ctc_input_length),
82-      axis=1)
64+  ctc_input_length = tf.cast(tf.multiply(
65+      input_length, ctc_time_steps), dtype=tf.float32)
66+  return tf.cast(tf.math.floordiv(
67+      ctc_input_length, tf.cast(max_time_steps, dtype=tf.float32)), dtype=tf.int32)
8368
8469
8570def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
@@ -123,11 +108,11 @@ def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
123108  total_cer /= num_of_examples
124109  total_wer /= num_of_examples
125110
126-  global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
111+  global_step = estimator.get_variable_value(tf.compat.v1.GraphKeys.GLOBAL_STEP)
127112  eval_results = {
128113      _WER_KEY: total_wer,
129114      _CER_KEY: total_cer,
130-      tf.GraphKeys.GLOBAL_STEP: global_step,
115+      tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step,
131116  }
132117
133118  return eval_results
@@ -163,7 +148,7 @@ def model_fn(features, labels, mode, params):
163148  logits = model(features, training=False)
164149  predictions = {
165150      "classes": tf.argmax(logits, axis=2),
166-      "probabilities": tf.nn.softmax(logits),
151+      "probabilities": logits,
167152      "logits": logits
168153  }
169154  return tf.estimator.EstimatorSpec(
@@ -172,17 +157,16 @@ def model_fn(features, labels, mode, params):
172157
173158  # In training mode.
174159  logits = model(features, training=True)
175-  probs = tf.nn.softmax(logits)
176160  ctc_input_length = compute_length_after_conv(
177-      tf.shape(features)[1], tf.shape(probs)[1], input_length)
161+      tf.shape(features)[1], tf.shape(logits)[1], input_length)
178162  # Compute CTC loss
179-  loss = tf.reduce_mean(ctc_loss(
180-      label_length, ctc_input_length, labels, probs))
163+  loss = tf.reduce_mean(tf.keras.backend.ctc_batch_cost(
164+      labels, logits, ctc_input_length, label_length))
181165
182-  optimizer = tf.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
183-  global_step = tf.train.get_or_create_global_step()
166+  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
167+  global_step = tf.compat.v1.train.get_or_create_global_step()
184168  minimize_op = optimizer.minimize(loss, global_step=global_step)
185-  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
169+  update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
186170  # Create the train_op that groups both minimize_ops and update_ops
187171  train_op = tf.group(minimize_op, update_ops)
188172
@@ -239,9 +223,9 @@ def per_device_batch_size(batch_size, num_gpus):
239223
240224def run_deep_speech(_):
241225  """Run deep speech training and eval loop."""
242-  tf.set_random_seed(flags_obj.seed)
226+  tf.compat.v1.set_random_seed(flags_obj.seed)
243227  # Data preprocessing
244-  tf.logging.info("Data preprocessing...")
228+  tf.compat.v1.logging.info("Data preprocessing...")
245229  train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
246230  eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)
247231
@@ -287,7 +271,7 @@ def input_fn_eval():
287271  total_training_cycle = (flags_obj.train_epochs //
288272                          flags_obj.epochs_between_evals)
289273  for cycle_index in range(total_training_cycle):
290-    tf.logging.info("Starting a training cycle: %d/%d",
274+    tf.compat.v1.logging.info("Starting a training cycle: %d/%d",
291275                    cycle_index + 1, total_training_cycle)
292276
293277 # Perform batch_wise dataset shuffling
@@ -298,15 +282,15 @@ def input_fn_eval():
298282    estimator.train(input_fn=input_fn_train)
299283
300284    # Evaluation
301-    tf.logging.info("Starting to evaluate...")
285+    tf.compat.v1.logging.info("Starting to evaluate...")
302286
303287    eval_results = evaluate_model(
304288        estimator, eval_speech_dataset.speech_labels,
305289        eval_speech_dataset.entries, input_fn_eval)
306290
307291    # Log the WER and CER results.
308292    benchmark_logger.log_evaluation_result(eval_results)
309-    tf.logging.info(
293+    tf.compat.v1.logging.info(
310294        "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
311295            cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))
312296
@@ -425,7 +409,7 @@ def main(_):
425409
426410
427411if __name__ == "__main__":
428-  tf.logging.set_verbosity(tf.logging.INFO)
412+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
429413  define_deep_speech_flags()
430414  flags_obj = flags.FLAGS
431415  absl_app.run(main)
0 commit comments