Skip to content

Commit 5f05ce2

Browse files
LukeWoodTF Object Detection Team
authored andcommitted
Explicitly import estimator from tensorflow as a separate import instead of accessing it via tf.estimator and depend on the tensorflow estimator target.
PiperOrigin-RevId: 437338390
1 parent 7368d50 commit 5f05ce2

File tree

5 files changed

+46
-41
lines changed

5 files changed

+46
-41
lines changed

research/object_detection/inputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import functools
2222

2323
import tensorflow.compat.v1 as tf
24+
from tensorflow.compat.v1 import estimator as tf_estimator
2425
from object_detection.builders import dataset_builder
2526
from object_detection.builders import image_resizer_builder
2627
from object_detection.builders import model_builder
@@ -1114,7 +1115,7 @@ def _predict_input_fn(params=None):
11141115
true_image_shape = tf.expand_dims(
11151116
input_dict[fields.InputDataFields.true_image_shape], axis=0)
11161117

1117-
return tf.estimator.export.ServingInputReceiver(
1118+
return tf_estimator.export.ServingInputReceiver(
11181119
features={
11191120
fields.InputDataFields.image: images,
11201121
fields.InputDataFields.true_image_shape: true_image_shape},

research/object_detection/model_lib.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import os
2424

2525
import tensorflow.compat.v1 as tf
26+
from tensorflow.compat.v1 import estimator as tf_estimator
2627
import tensorflow.compat.v2 as tf2
2728
import tf_slim as slim
2829

@@ -465,7 +466,7 @@ def model_fn(features, labels, mode, params=None):
465466
"""
466467
params = params or {}
467468
total_loss, train_op, detections, export_outputs = None, None, None, None
468-
is_training = mode == tf.estimator.ModeKeys.TRAIN
469+
is_training = mode == tf_estimator.ModeKeys.TRAIN
469470

470471
# Make sure to set the Keras learning phase. True during training,
471472
# False for inference.
@@ -479,11 +480,11 @@ def model_fn(features, labels, mode, params=None):
479480
is_training=is_training, add_summaries=(not use_tpu))
480481
scaffold_fn = None
481482

482-
if mode == tf.estimator.ModeKeys.TRAIN:
483+
if mode == tf_estimator.ModeKeys.TRAIN:
483484
labels = unstack_batch(
484485
labels,
485486
unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
486-
elif mode == tf.estimator.ModeKeys.EVAL:
487+
elif mode == tf_estimator.ModeKeys.EVAL:
487488
# For evaling on train data, it is necessary to check whether groundtruth
488489
# must be unpadded.
489490
boxes_shape = (
@@ -493,7 +494,7 @@ def model_fn(features, labels, mode, params=None):
493494
labels = unstack_batch(
494495
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
495496

496-
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
497+
if mode in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
497498
provide_groundtruth(detection_model, labels)
498499

499500
preprocessed_images = features[fields.InputDataFields.image]
@@ -514,7 +515,7 @@ def model_fn(features, labels, mode, params=None):
514515
def postprocess_wrapper(args):
515516
return detection_model.postprocess(args[0], args[1])
516517

517-
if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
518+
if mode in (tf_estimator.ModeKeys.EVAL, tf_estimator.ModeKeys.PREDICT):
518519
if use_tpu and postprocess_on_cpu:
519520
detections = tf.tpu.outside_compilation(
520521
postprocess_wrapper,
@@ -525,7 +526,7 @@ def postprocess_wrapper(args):
525526
prediction_dict,
526527
features[fields.InputDataFields.true_image_shape]))
527528

528-
if mode == tf.estimator.ModeKeys.TRAIN:
529+
if mode == tf_estimator.ModeKeys.TRAIN:
529530
load_pretrained = hparams.load_pretrained if hparams else False
530531
if train_config.fine_tune_checkpoint and load_pretrained:
531532
if not train_config.fine_tune_checkpoint_type:
@@ -557,8 +558,8 @@ def tpu_scaffold():
557558
tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
558559
available_var_map)
559560

560-
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
561-
if (mode == tf.estimator.ModeKeys.EVAL and
561+
if mode in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
562+
if (mode == tf_estimator.ModeKeys.EVAL and
562563
eval_config.use_dummy_loss_in_eval):
563564
total_loss = tf.constant(1.0)
564565
losses_dict = {'Loss/total_loss': total_loss}
@@ -590,7 +591,7 @@ def tpu_scaffold():
590591
training_optimizer, optimizer_summary_vars = optimizer_builder.build(
591592
train_config.optimizer)
592593

593-
if mode == tf.estimator.ModeKeys.TRAIN:
594+
if mode == tf_estimator.ModeKeys.TRAIN:
594595
if use_tpu:
595596
training_optimizer = tf.tpu.CrossShardOptimizer(training_optimizer)
596597

@@ -628,16 +629,16 @@ def tpu_scaffold():
628629
summaries=summaries,
629630
name='') # Preventing scope prefix on all variables.
630631

631-
if mode == tf.estimator.ModeKeys.PREDICT:
632+
if mode == tf_estimator.ModeKeys.PREDICT:
632633
exported_output = exporter_lib.add_output_tensor_nodes(detections)
633634
export_outputs = {
634635
tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
635-
tf.estimator.export.PredictOutput(exported_output)
636+
tf_estimator.export.PredictOutput(exported_output)
636637
}
637638

638639
eval_metric_ops = None
639640
scaffold = None
640-
if mode == tf.estimator.ModeKeys.EVAL:
641+
if mode == tf_estimator.ModeKeys.EVAL:
641642
class_agnostic = (
642643
fields.DetectionResultFields.detection_classes not in detections)
643644
groundtruth = _prepare_groundtruth_for_eval(
@@ -711,8 +712,8 @@ def tpu_scaffold():
711712
scaffold = tf.train.Scaffold(saver=saver)
712713

713714
# EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
714-
if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
715-
return tf.estimator.tpu.TPUEstimatorSpec(
715+
if use_tpu and mode != tf_estimator.ModeKeys.EVAL:
716+
return tf_estimator.tpu.TPUEstimatorSpec(
716717
mode=mode,
717718
scaffold_fn=scaffold_fn,
718719
predictions=detections,
@@ -730,7 +731,7 @@ def tpu_scaffold():
730731
save_relative_paths=True)
731732
tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
732733
scaffold = tf.train.Scaffold(saver=saver)
733-
return tf.estimator.EstimatorSpec(
734+
return tf_estimator.EstimatorSpec(
734735
mode=mode,
735736
predictions=detections,
736737
loss=total_loss,
@@ -895,7 +896,7 @@ def create_estimator_and_inputs(run_config,
895896
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu,
896897
postprocess_on_cpu)
897898
if use_tpu_estimator:
898-
estimator = tf.estimator.tpu.TPUEstimator(
899+
estimator = tf_estimator.tpu.TPUEstimator(
899900
model_fn=model_fn,
900901
train_batch_size=train_config.batch_size,
901902
# For each core, only batch size 1 is supported for eval.
@@ -906,7 +907,7 @@ def create_estimator_and_inputs(run_config,
906907
eval_on_tpu=False, # Eval runs on CPU, so disable eval on TPU
907908
params=params if params else {})
908909
else:
909-
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
910+
estimator = tf_estimator.Estimator(model_fn=model_fn, config=run_config)
910911

911912
# Write the as-run pipeline config to disk.
912913
if run_config.is_chief and save_final_config:
@@ -951,7 +952,7 @@ def create_train_and_eval_specs(train_input_fn,
951952
True, the last `EvalSpec` in the list will correspond to training data. The
952953
rest EvalSpecs in the list are evaluation datas.
953954
"""
954-
train_spec = tf.estimator.TrainSpec(
955+
train_spec = tf_estimator.TrainSpec(
955956
input_fn=train_input_fn, max_steps=train_steps)
956957

957958
if eval_spec_names is None:
@@ -966,18 +967,18 @@ def create_train_and_eval_specs(train_input_fn,
966967
exporter_name = final_exporter_name
967968
else:
968969
exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name)
969-
exporter = tf.estimator.FinalExporter(
970+
exporter = tf_estimator.FinalExporter(
970971
name=exporter_name, serving_input_receiver_fn=predict_input_fn)
971972
eval_specs.append(
972-
tf.estimator.EvalSpec(
973+
tf_estimator.EvalSpec(
973974
name=eval_spec_name,
974975
input_fn=eval_input_fn,
975976
steps=None,
976977
exporters=exporter))
977978

978979
if eval_on_train_data:
979980
eval_specs.append(
980-
tf.estimator.EvalSpec(
981+
tf_estimator.EvalSpec(
981982
name='eval_on_train', input_fn=eval_on_train_input_fn, steps=None))
982983

983984
return train_spec, eval_specs

research/object_detection/model_lib_tf1_test.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import unittest
2424
import numpy as np
2525
import tensorflow.compat.v1 as tf
26+
from tensorflow.compat.v1 import estimator as tf_estimator
2627

2728
from object_detection import inputs
2829
from object_detection import model_hparams
@@ -137,21 +138,21 @@ def _assert_model_fn_for_train_eval(self, configs, mode,
137138
inputs.create_train_input_fn(configs['train_config'],
138139
configs['train_input_config'],
139140
configs['model'])()).get_next()
140-
model_mode = tf.estimator.ModeKeys.TRAIN
141+
model_mode = tf_estimator.ModeKeys.TRAIN
141142
batch_size = train_config.batch_size
142143
elif mode == 'eval':
143144
features, labels = _make_initializable_iterator(
144145
inputs.create_eval_input_fn(configs['eval_config'],
145146
configs['eval_input_config'],
146147
configs['model'])()).get_next()
147-
model_mode = tf.estimator.ModeKeys.EVAL
148+
model_mode = tf_estimator.ModeKeys.EVAL
148149
batch_size = 1
149150
elif mode == 'eval_on_train':
150151
features, labels = _make_initializable_iterator(
151152
inputs.create_eval_input_fn(configs['eval_config'],
152153
configs['train_input_config'],
153154
configs['model'])()).get_next()
154-
model_mode = tf.estimator.ModeKeys.EVAL
155+
model_mode = tf_estimator.ModeKeys.EVAL
155156
batch_size = 1
156157

157158
detection_model_fn = functools.partial(
@@ -183,7 +184,7 @@ def _assert_model_fn_for_train_eval(self, configs, mode,
183184
if mode == 'eval':
184185
self.assertIn('Detections_Left_Groundtruth_Right/0',
185186
estimator_spec.eval_metric_ops)
186-
if model_mode == tf.estimator.ModeKeys.TRAIN:
187+
if model_mode == tf_estimator.ModeKeys.TRAIN:
187188
self.assertIsNotNone(estimator_spec.train_op)
188189
return estimator_spec
189190

@@ -202,7 +203,7 @@ def _assert_model_fn_for_predict(self, configs):
202203
hparams_overrides='load_pretrained=false')
203204

204205
model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams)
205-
estimator_spec = model_fn(features, None, tf.estimator.ModeKeys.PREDICT)
206+
estimator_spec = model_fn(features, None, tf_estimator.ModeKeys.PREDICT)
206207

207208
self.assertIsNone(estimator_spec.loss)
208209
self.assertIsNone(estimator_spec.train_op)
@@ -279,7 +280,7 @@ def test_model_fn_in_predict_mode(self):
279280

280281
def test_create_estimator_and_inputs(self):
281282
"""Tests that Estimator and input function are constructed correctly."""
282-
run_config = tf.estimator.RunConfig()
283+
run_config = tf_estimator.RunConfig()
283284
hparams = model_hparams.create_hparams(
284285
hparams_overrides='load_pretrained=false')
285286
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
@@ -291,15 +292,15 @@ def test_create_estimator_and_inputs(self):
291292
train_steps=train_steps)
292293
estimator = train_and_eval_dict['estimator']
293294
train_steps = train_and_eval_dict['train_steps']
294-
self.assertIsInstance(estimator, tf.estimator.Estimator)
295+
self.assertIsInstance(estimator, tf_estimator.Estimator)
295296
self.assertEqual(20, train_steps)
296297
self.assertIn('train_input_fn', train_and_eval_dict)
297298
self.assertIn('eval_input_fns', train_and_eval_dict)
298299
self.assertIn('eval_on_train_input_fn', train_and_eval_dict)
299300

300301
def test_create_estimator_and_inputs_sequence_example(self):
301302
"""Tests that Estimator and input function are constructed correctly."""
302-
run_config = tf.estimator.RunConfig()
303+
run_config = tf_estimator.RunConfig()
303304
hparams = model_hparams.create_hparams(
304305
hparams_overrides='load_pretrained=false')
305306
pipeline_config_path = get_pipeline_config_path(
@@ -312,15 +313,15 @@ def test_create_estimator_and_inputs_sequence_example(self):
312313
train_steps=train_steps)
313314
estimator = train_and_eval_dict['estimator']
314315
train_steps = train_and_eval_dict['train_steps']
315-
self.assertIsInstance(estimator, tf.estimator.Estimator)
316+
self.assertIsInstance(estimator, tf_estimator.Estimator)
316317
self.assertEqual(20, train_steps)
317318
self.assertIn('train_input_fn', train_and_eval_dict)
318319
self.assertIn('eval_input_fns', train_and_eval_dict)
319320
self.assertIn('eval_on_train_input_fn', train_and_eval_dict)
320321

321322
def test_create_estimator_with_default_train_eval_steps(self):
322323
"""Tests that number of train/eval defaults to config values."""
323-
run_config = tf.estimator.RunConfig()
324+
run_config = tf_estimator.RunConfig()
324325
hparams = model_hparams.create_hparams(
325326
hparams_overrides='load_pretrained=false')
326327
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
@@ -331,12 +332,12 @@ def test_create_estimator_with_default_train_eval_steps(self):
331332
estimator = train_and_eval_dict['estimator']
332333
train_steps = train_and_eval_dict['train_steps']
333334

334-
self.assertIsInstance(estimator, tf.estimator.Estimator)
335+
self.assertIsInstance(estimator, tf_estimator.Estimator)
335336
self.assertEqual(config_train_steps, train_steps)
336337

337338
def test_create_tpu_estimator_and_inputs(self):
338339
"""Tests that number of train/eval defaults to config values."""
339-
run_config = tf.estimator.tpu.RunConfig()
340+
run_config = tf_estimator.tpu.RunConfig()
340341
hparams = model_hparams.create_hparams(
341342
hparams_overrides='load_pretrained=false')
342343
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
@@ -350,12 +351,12 @@ def test_create_tpu_estimator_and_inputs(self):
350351
estimator = train_and_eval_dict['estimator']
351352
train_steps = train_and_eval_dict['train_steps']
352353

353-
self.assertIsInstance(estimator, tf.estimator.tpu.TPUEstimator)
354+
self.assertIsInstance(estimator, tf_estimator.tpu.TPUEstimator)
354355
self.assertEqual(20, train_steps)
355356

356357
def test_create_train_and_eval_specs(self):
357358
"""Tests that `TrainSpec` and `EvalSpec` is created correctly."""
358-
run_config = tf.estimator.RunConfig()
359+
run_config = tf_estimator.RunConfig()
359360
hparams = model_hparams.create_hparams(
360361
hparams_overrides='load_pretrained=false')
361362
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
@@ -390,7 +391,7 @@ def test_create_train_and_eval_specs(self):
390391

391392
def test_experiment(self):
392393
"""Tests that the `Experiment` object is constructed correctly."""
393-
run_config = tf.estimator.RunConfig()
394+
run_config = tf_estimator.RunConfig()
394395
hparams = model_hparams.create_hparams(
395396
hparams_overrides='load_pretrained=false')
396397
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)

research/object_detection/model_main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from absl import flags
2222

2323
import tensorflow.compat.v1 as tf
24+
from tensorflow.compat.v1 import estimator as tf_estimator
2425

2526
from object_detection import model_lib
2627

@@ -59,7 +60,7 @@
5960
def main(unused_argv):
6061
flags.mark_flag_as_required('model_dir')
6162
flags.mark_flag_as_required('pipeline_config_path')
62-
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
63+
config = tf_estimator.RunConfig(model_dir=FLAGS.model_dir)
6364

6465
train_and_eval_dict = model_lib.create_estimator_and_inputs(
6566
run_config=config,
@@ -101,7 +102,7 @@ def main(unused_argv):
101102
eval_on_train_data=False)
102103

103104
# Currently only a single Eval Spec is allowed.
104-
tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
105+
tf_estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
105106

106107

107108
if __name__ == '__main__':

research/object_detection/model_tpu_main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from absl import flags
2626
import tensorflow.compat.v1 as tf
27+
from tensorflow.compat.v1 import estimator as tf_estimator
2728

2829

2930
from object_detection import model_lib
@@ -89,11 +90,11 @@ def main(unused_argv):
8990
tpu=[FLAGS.tpu_name], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
9091
tpu_grpc_url = tpu_cluster_resolver.get_master()
9192

92-
config = tf.estimator.tpu.RunConfig(
93+
config = tf_estimator.tpu.RunConfig(
9394
master=tpu_grpc_url,
9495
evaluation_master=tpu_grpc_url,
9596
model_dir=FLAGS.model_dir,
96-
tpu_config=tf.estimator.tpu.TPUConfig(
97+
tpu_config=tf_estimator.tpu.TPUConfig(
9798
iterations_per_loop=FLAGS.iterations_per_loop,
9899
num_shards=FLAGS.num_shards))
99100

0 commit comments

Comments
 (0)