|
| 1 | +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import print_function |
| 16 | +from utils.static_ps.reader_helper import get_reader, get_example_num, get_file_list, get_word_num |
| 17 | +from utils.static_ps.program_helper import get_model, get_strategy |
| 18 | +from utils.static_ps.common import YamlHelper, is_distributed_env |
| 19 | +import argparse |
| 20 | +import time |
| 21 | +import sys |
| 22 | +import paddle.distributed.fleet as fleet |
| 23 | +import paddle.distributed.fleet.base.role_maker as role_maker |
| 24 | +import paddle |
| 25 | +import os |
| 26 | +import warnings |
| 27 | +import logging |
| 28 | +import paddle.fluid as fluid |
| 29 | + |
| 30 | +__dir__ = os.path.dirname(os.path.abspath(__file__)) |
| 31 | +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) |
| 32 | + |
| 33 | +logging.basicConfig( |
| 34 | + format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) |
| 35 | +logger = logging.getLogger(__name__) |
| 36 | + |
| 37 | + |
def parse_args():
    """Parse command-line arguments and load the training YAML config.

    Returns:
        dict: the loaded config, augmented with ``yaml_path`` (the path
        passed on the command line) and ``config_abs_dir`` (its absolute
        parent directory).
    """
    parser = argparse.ArgumentParser("PaddleRec train script")
    parser.add_argument(
        '-m',
        '--config_yaml',
        type=str,
        required=True,
        help='config file path')
    args = parser.parse_args()
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))

    helper = YamlHelper()
    conf = helper.load_yaml(args.config_yaml)
    # Record where the config came from so downstream code can resolve
    # paths relative to the YAML file.
    conf["yaml_path"] = args.config_yaml
    conf["config_abs_dir"] = args.abs_dir
    helper.print_yaml(conf)
    return conf
| 54 | + |
| 55 | + |
class Main(object):
    """Parameter-server online training driver.

    Builds the model network once, then runs either the server role or the
    online-worker role as assigned by fleet. The worker iterates over
    (day, pass) pairs, loading an in-memory dataset and training on it.
    """

    def __init__(self, config):
        # config: dict loaded from the runner YAML (see parse_args).
        self.metrics = {}
        self.config = config
        self.input_data = None
        self.reader = None
        self.exe = None
        self.train_result_dict = {}
        self.train_result_dict["speed"] = []

    def run(self):
        """Entry point: init fleet, build the network, dispatch by role."""
        fleet.init()
        self.network()
        if fleet.is_server():
            self.run_server()
        elif fleet.is_worker():
            self.run_online_worker()
            fleet.stop_worker()
            self.record_result()
        logger.info("Run Success, Exit.")

    def record_result(self):
        """Log the accumulated training results.

        Fix: this method was called from ``run`` but never defined, which
        raised AttributeError at the end of every worker run.
        """
        logger.info("train_result_dict: {}".format(self.train_result_dict))

    def network(self):
        """Build feeds, net, and optimizer from the configured model."""
        model = get_model(self.config)
        self.input_data = model.create_feeds()
        self.metrics = model.net(self.input_data)
        self.inference_target_var = model.inference_target_var
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        model.create_optimizer(get_strategy(self.config))

    def run_server(self):
        """Initialize (optionally from a warmup model) and run the PS."""
        logger.info("Run Server Begin")
        # Fix: read from self.config instead of the module-level ``config``
        # global, so the class works when constructed programmatically.
        fleet.init_server(self.config.get("runner.warmup_model_path"))
        fleet.run_server()

    def wait_and_prepare_dataset(self, day, pass_index):
        """Build and load an InMemoryDataset for one (day, pass).

        ``day`` and ``pass_index`` are accepted so deployments can select a
        per-pass file list; this default implementation simply reads every
        file under each directory in ``runner.train_data_dir``.
        """
        train_data_dir = self.config.get("runner.train_data_dir", [])

        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
        dataset.set_use_var(self.input_data)
        dataset.set_batch_size(self.config.get('runner.train_batch_size'))
        dataset.set_thread(self.config.get('runner.train_thread_num'))

        # May need a per-(day, pass_index) filelist in real deployments.
        filelist = []
        for path in train_data_dir:
            filelist += [os.path.join(path, x) for x in os.listdir(path)]

        dataset.set_filelist(filelist)
        dataset.set_pipe_command(self.config.get("runner.pipe_command"))
        dataset.load_into_memory()
        return dataset

    def run_online_worker(self):
        """Worker loop: train one pass at a time and checkpoint models."""
        logger.info("Run Online Worker Begin")
        # Fix: use self.config, not the module-level ``config`` global.
        use_cuda = int(self.config.get("runner.use_gpu"))
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)

        # Dump the programs for debugging, one file per worker.
        with open("./{}_worker_main_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open("./{}_worker_startup_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))

        self.exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        save_model_path = self.config.get("runner.model_save_path")
        if save_model_path and not os.path.exists(save_model_path):
            os.makedirs(save_model_path)

        # NOTE: ``runner.days`` is deliberately expanded through the shell
        # (e.g. it may contain ``$(seq ...)``); ``echo -n`` drops the
        # trailing newline before splitting on spaces.
        days = os.popen("echo -n " +
                        self.config.get("runner.days")).read().split(" ")
        pass_per_day = int(self.config.get("runner.pass_per_day"))

        for day in days:
            for pass_index in range(1, pass_per_day + 1):
                logger.info("Day: {} Pass: {} Begin.".format(day, pass_index))

                prepare_data_start_time = time.time()
                dataset = self.wait_and_prepare_dataset(day, pass_index)
                logger.info(
                    "Prepare Dataset Done, using time {} second.".format(
                        time.time() - prepare_data_start_time))

                train_start_time = time.time()
                self.dataset_train_loop(dataset, day, pass_index)
                logger.info(
                    "Train Dataset Done, using time {} second.".format(
                        time.time() - train_start_time))

                # Fix: merged the two duplicated guard blocks (they differed
                # only in ``mode``) and moved ``model_dir`` inside the guard.
                if (fleet.is_first_worker() and save_model_path and
                        is_distributed_env()):
                    model_dir = "{}/{}/{}".format(save_model_path, day,
                                                  pass_index)
                    # mode=2 saves the delta model, mode=0 the base model.
                    for save_mode in (2, 0):
                        fleet.save_inference_model(
                            self.exe, model_dir,
                            [feed.name for feed in self.input_data],
                            self.inference_target_var,
                            mode=save_mode)

    def dataset_train_loop(self, cur_dataset, day, pass_index):
        """Run one training pass over ``cur_dataset``, then free it."""
        logger.info("Day: {} Pass: {}, Running Dataset Begin.".format(
            day, pass_index))
        fetch_info = [
            "Day: {} Pass: {} Var {}".format(day, pass_index, var_name)
            for var_name in self.metrics
        ]
        fetch_vars = [var for _, var in self.metrics.items()]
        # Fix: self.config instead of the module-level ``config`` global.
        print_step = int(self.config.get("runner.print_interval"))
        self.exe.train_from_dataset(
            program=paddle.static.default_main_program(),
            dataset=cur_dataset,
            fetch_list=fetch_vars,
            fetch_info=fetch_info,
            print_period=print_step,
            debug=self.config.get("runner.dataset_debug"))
        cur_dataset.release_memory()
| 180 | + |
if __name__ == "__main__":
    # Fleet parameter-server training runs on the static graph.
    paddle.enable_static()
    config = parse_args()
    # os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
    benchmark_main = Main(config)
    benchmark_main.run()
0 commit comments