Skip to content

Commit 432a448

Browse files
saberkuntensorflower-gardener
authored andcommitted
Move files to core/ and common/
PiperOrigin-RevId: 326586473
1 parent a003b7c commit 432a448

File tree

6 files changed

+457
-0
lines changed

6 files changed

+457
-0
lines changed

official/common/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

official/common/flags.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Lint as: python3
2+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""The central place to define flags."""
17+
18+
from absl import flags
19+
20+
21+
def define_flags():
22+
"""Defines flags."""
23+
flags.DEFINE_string(
24+
'experiment', default=None, help='The experiment type registered.')
25+
26+
flags.DEFINE_enum(
27+
'mode',
28+
default=None,
29+
enum_values=['train', 'eval', 'train_and_eval',
30+
'continuous_eval', 'continuous_train_and_eval'],
31+
help='Mode to run: `train`, `eval`, `train_and_eval`, '
32+
'`continuous_eval`, and `continuous_train_and_eval`.')
33+
34+
flags.DEFINE_string(
35+
'model_dir',
36+
default=None,
37+
help='The directory where the model and training/evaluation summaries'
38+
'are stored.')
39+
40+
flags.DEFINE_multi_string(
41+
'config_file',
42+
default=None,
43+
help='YAML/JSON files which specifies overrides. The override order '
44+
'follows the order of args. Note that each file '
45+
'can be used as an override template to override the default parameters '
46+
'specified in Python. If the same parameter is specified in both '
47+
'`--config_file` and `--params_override`, `config_file` will be used '
48+
'first, followed by params_override.')
49+
50+
flags.DEFINE_string(
51+
'params_override',
52+
default=None,
53+
help='a YAML/JSON string or a YAML file which specifies additional '
54+
'overrides over the default parameters and those specified in '
55+
'`--config_file`. Note that this is supposed to be used only to override '
56+
'the model parameters, but not the parameters like TPU specific flags. '
57+
'One canonical use case of `--config_file` and `--params_override` is '
58+
'users first define a template config file using `--config_file`, then '
59+
'use `--params_override` to adjust the minimal set of tuning parameters, '
60+
'for example setting up different `train_batch_size`. The final override '
61+
'order of parameters: default_model_params --> params from config_file '
62+
'--> params in params_override. See also the help message of '
63+
'`--config_file`.')
64+
65+
flags.DEFINE_multi_string(
66+
'gin_file', default=None, help='List of paths to the config files.')
67+
68+
flags.DEFINE_multi_string(
69+
'gin_params',
70+
default=None,
71+
help='Newline separated list of Gin parameter bindings.')
72+
73+
flags.DEFINE_string(
74+
'tpu', default=None,
75+
help='The Cloud TPU to use for training. This should be either the name '
76+
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
77+
'url.')

official/common/registry_imports.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""All necessary imports for registration."""
16+
17+
# pylint: disable=unused-import
18+
from official.nlp import tasks
19+
from official.utils.testing import mock_task

official/core/train_lib.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Lint as: python3
2+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""TFM common training driver library."""
17+
18+
import os
19+
from typing import Any, Mapping
20+
21+
# Import libraries
22+
from absl import logging
23+
import orbit
24+
import tensorflow as tf
25+
26+
from official.common import train_utils
27+
from official.core import base_task
28+
from official.modeling.hyperparams import config_definitions
29+
30+
31+
def run_experiment(distribution_strategy: tf.distribute.Strategy,
32+
task: base_task.Task,
33+
mode: str,
34+
params: config_definitions.ExperimentConfig,
35+
model_dir: str,
36+
run_post_eval: bool = False,
37+
save_summary: bool = True) -> Mapping[str, Any]:
38+
"""Runs train/eval configured by the experiment params.
39+
40+
Args:
41+
distribution_strategy: A distribution distribution_strategy.
42+
task: A Task instance.
43+
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
44+
or 'continuous_eval'.
45+
params: ExperimentConfig instance.
46+
model_dir: A 'str', a path to store model checkpoints and summaries.
47+
run_post_eval: Whether to run post eval once after training, metrics logs
48+
are returned.
49+
save_summary: Whether to save train and validation summary.
50+
51+
Returns:
52+
eval logs: returns eval metrics logs when run_post_eval is set to True,
53+
othewise, returns {}.
54+
"""
55+
56+
with distribution_strategy.scope():
57+
trainer = train_utils.create_trainer(
58+
params,
59+
task,
60+
model_dir,
61+
train='train' in mode,
62+
evaluate=('eval' in mode) or run_post_eval)
63+
64+
if trainer.checkpoint:
65+
checkpoint_manager = tf.train.CheckpointManager(
66+
trainer.checkpoint,
67+
directory=model_dir,
68+
max_to_keep=params.trainer.max_to_keep,
69+
step_counter=trainer.global_step,
70+
checkpoint_interval=params.trainer.checkpoint_interval,
71+
init_fn=trainer.initialize)
72+
else:
73+
checkpoint_manager = None
74+
75+
controller = orbit.Controller(
76+
distribution_strategy,
77+
trainer=trainer if 'train' in mode else None,
78+
evaluator=trainer,
79+
global_step=trainer.global_step,
80+
steps_per_loop=params.trainer.steps_per_loop,
81+
checkpoint_manager=checkpoint_manager,
82+
summary_dir=os.path.join(model_dir, 'train') if (
83+
save_summary) else None,
84+
eval_summary_dir=os.path.join(model_dir, 'validation') if (
85+
save_summary) else None,
86+
summary_interval=params.trainer.summary_interval if (
87+
save_summary) else None)
88+
89+
logging.info('Starts to execute mode: %s', mode)
90+
with distribution_strategy.scope():
91+
if mode == 'train':
92+
controller.train(steps=params.trainer.train_steps)
93+
elif mode == 'train_and_eval':
94+
controller.train_and_evaluate(
95+
train_steps=params.trainer.train_steps,
96+
eval_steps=params.trainer.validation_steps,
97+
eval_interval=params.trainer.validation_interval)
98+
elif mode == 'eval':
99+
controller.evaluate(steps=params.trainer.validation_steps)
100+
elif mode == 'continuous_eval':
101+
controller.evaluate_continuously(
102+
steps=params.trainer.validation_steps,
103+
timeout=params.trainer.continuous_eval_timeout)
104+
else:
105+
raise NotImplementedError('The mode is not implemented: %s' % mode)
106+
107+
if run_post_eval:
108+
with distribution_strategy.scope():
109+
return trainer.evaluate(
110+
tf.convert_to_tensor(params.trainer.validation_steps))
111+
else:
112+
return {}

official/core/train_lib_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Lint as: python3
2+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""Tests for train_ctl_lib."""
17+
import json
18+
import os
19+
20+
from absl import flags
21+
from absl.testing import flagsaver
22+
from absl.testing import parameterized
23+
import tensorflow as tf
24+
25+
from tensorflow.python.distribute import combinations
26+
from tensorflow.python.distribute import strategy_combinations
27+
from official.common import flags as tfm_flags
28+
# pylint: disable=unused-import
29+
from official.common import registry_imports
30+
# pylint: enable=unused-import
31+
from official.core import task_factory
32+
from official.core import train_lib
33+
from official.core import train_utils
34+
35+
FLAGS = flags.FLAGS
36+
37+
tfm_flags.define_flags()
38+
39+
40+
class TrainTest(tf.test.TestCase, parameterized.TestCase):
41+
42+
def setUp(self):
43+
super(TrainTest, self).setUp()
44+
self._test_config = {
45+
'trainer': {
46+
'checkpoint_interval': 10,
47+
'steps_per_loop': 10,
48+
'summary_interval': 10,
49+
'train_steps': 10,
50+
'validation_steps': 5,
51+
'validation_interval': 10,
52+
'optimizer_config': {
53+
'optimizer': {
54+
'type': 'sgd',
55+
},
56+
'learning_rate': {
57+
'type': 'constant'
58+
}
59+
}
60+
},
61+
}
62+
63+
@combinations.generate(
64+
combinations.combine(
65+
distribution_strategy=[
66+
strategy_combinations.default_strategy,
67+
strategy_combinations.tpu_strategy,
68+
strategy_combinations.one_device_strategy_gpu,
69+
],
70+
mode='eager',
71+
flag_mode=['train', 'eval', 'train_and_eval'],
72+
run_post_eval=[True, False]))
73+
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
74+
model_dir = self.get_temp_dir()
75+
flags_dict = dict(
76+
experiment='mock',
77+
mode=flag_mode,
78+
model_dir=model_dir,
79+
params_override=json.dumps(self._test_config))
80+
with flagsaver.flagsaver(**flags_dict):
81+
params = train_utils.parse_configuration(flags.FLAGS)
82+
train_utils.serialize_config(params, model_dir)
83+
with distribution_strategy.scope():
84+
task = task_factory.get_task(params.task, logging_dir=model_dir)
85+
86+
logs = train_lib.run_experiment(
87+
distribution_strategy=distribution_strategy,
88+
task=task,
89+
mode=flag_mode,
90+
params=params,
91+
model_dir=model_dir,
92+
run_post_eval=run_post_eval)
93+
94+
if run_post_eval:
95+
self.assertNotEmpty(logs)
96+
else:
97+
self.assertEmpty(logs)
98+
self.assertNotEmpty(
99+
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
100+
if flag_mode != 'eval':
101+
self.assertNotEmpty(
102+
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
103+
104+
105+
if __name__ == '__main__':
106+
tf.test.main()

0 commit comments

Comments
 (0)