Skip to content

Commit e2e21d2

Browse files
authored
Merge pull request #373 from 123malin/online_trainer
test=develop, add static_ps_online_trainer
2 parents 26dd5b9 + c97d796 commit e2e21d2

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed

models/rank/dnn/online.yaml

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hyper_parameters:
  optimizer:
    class: Adam
    learning_rate: 0.0001
    adam_lazy_mode: True
  sparse_inputs_slots: 27
  sparse_feature_number: 1000001
  sparse_feature_dim: 10
  dense_input_dim: 13
  fc_sizes: [400, 400, 400]

runner:
  train_data_dir: ["data/sample_data/train/"]
  # Brace range is expanded by the shell inside the trainer (os.popen "echo").
  days: "{20191225..20191227}"
  pass_per_day: 24

  train_batch_size: 12
  train_thread_num: 16
  geo_step: 400
  sync_mode: "async" # sync / async / geo / heter

  pipe_command: "python benchmark_reader.py"
  print_interval: 100

  use_gpu: 0

  model_path: "static_model.py"
  dataset_debug: False
  model_save_path: "model"

  # knock-in and knock-out
  # create_num_threshold: 1 # knock-in
  # max_keep_days: 60 # knock-out

tools/static_ps_online_trainer.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
from utils.static_ps.reader_helper import get_reader, get_example_num, get_file_list, get_word_num
17+
from utils.static_ps.program_helper import get_model, get_strategy
18+
from utils.static_ps.common import YamlHelper, is_distributed_env
19+
import argparse
20+
import time
21+
import sys
22+
import paddle.distributed.fleet as fleet
23+
import paddle.distributed.fleet.base.role_maker as role_maker
24+
import paddle
25+
import os
26+
import warnings
27+
import logging
28+
import paddle.fluid as fluid
29+
30+
__dir__ = os.path.dirname(os.path.abspath(__file__))
31+
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
32+
33+
logging.basicConfig(
34+
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
35+
logger = logging.getLogger(__name__)
36+
37+
38+
def parse_args():
    """Parse CLI arguments and load the YAML config they point at.

    Requires ``-m/--config_yaml PATH`` on the command line.

    Returns:
        dict: the parsed config, augmented with ``yaml_path`` (the path as
        given on the command line) and ``config_abs_dir`` (its absolute
        parent directory, used to resolve relative paths in the config).
    """
    parser = argparse.ArgumentParser("PaddleRec train script")
    parser.add_argument(
        '-m',
        '--config_yaml',
        type=str,
        required=True,
        help='config file path')
    args = parser.parse_args()
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    yaml_helper = YamlHelper()
    config = yaml_helper.load_yaml(args.config_yaml)
    config["yaml_path"] = args.config_yaml
    config["config_abs_dir"] = args.abs_dir
    yaml_helper.print_yaml(config)
    return config
54+
55+
56+
class Main(object):
    """Online (day/pass streaming) parameter-server trainer.

    Dispatches to a server loop or a worker training loop depending on the
    role assigned by fleet. All settings come from the flattened config dict
    produced by ``parse_args`` (``runner.*`` / ``hyper_parameters.*`` keys).
    """

    def __init__(self, config):
        self.metrics = {}
        self.config = config
        self.input_data = None
        self.reader = None
        self.exe = None
        # Accumulated run statistics, logged by record_result() at the end.
        self.train_result_dict = {}
        self.train_result_dict["speed"] = []

    def run(self):
        """Entry point: init fleet, build the network, then act by role."""
        fleet.init()
        self.network()
        if fleet.is_server():
            self.run_server()
        elif fleet.is_worker():
            self.run_online_worker()
            fleet.stop_worker()
            self.record_result()
        logger.info("Run Success, Exit.")

    def network(self):
        """Build feeds, metrics net and optimizer from the configured model."""
        model = get_model(self.config)
        self.input_data = model.create_feeds()
        self.metrics = model.net(self.input_data)
        self.inference_target_var = model.inference_target_var
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        model.create_optimizer(get_strategy(self.config))

    def run_server(self):
        """Start the parameter server, warm-started if a path is configured."""
        logger.info("Run Server Begin")
        # BUGFIX: read self.config instead of the module-level `config`
        # global, so the class also works when constructed programmatically.
        fleet.init_server(self.config.get("runner.warmup_model_path"))
        fleet.run_server()

    def wait_and_prepare_dataset(self, day, pass_index):
        """Build and fill an InMemoryDataset for one (day, pass_index).

        NOTE(review): `day` / `pass_index` are currently unused — the file
        list is every file under each runner.train_data_dir entry. Customize
        the filelist construction below to select per-pass data.
        """
        train_data_dir = self.config.get("runner.train_data_dir", [])

        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
        dataset.set_use_var(self.input_data)
        dataset.set_batch_size(self.config.get('runner.train_batch_size'))
        dataset.set_thread(self.config.get('runner.train_thread_num'))

        # may you need define your dataset_filelist for day/pass_index
        filelist = []
        for path in train_data_dir:
            filelist += [path + "/%s" % x for x in os.listdir(path)]

        dataset.set_filelist(filelist)
        dataset.set_pipe_command(self.config.get("runner.pipe_command"))
        dataset.load_into_memory()
        return dataset

    def run_online_worker(self):
        """Worker loop: for every day and pass, load data, train and save."""
        logger.info("Run Online Worker Begin")
        # BUGFIX: self.config, not the module-level `config` global.
        use_cuda = int(self.config.get("runner.use_gpu"))
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)

        # Dump the worker's programs for offline debugging.
        with open("./{}_worker_main_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open("./{}_worker_startup_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))

        self.exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        save_model_path = self.config.get("runner.model_save_path")
        if save_model_path and (not os.path.exists(save_model_path)):
            os.makedirs(save_model_path)

        # Expand a brace range such as "{20191225..20191227}" via the shell
        # into a space-separated list of day strings.
        days = os.popen("echo -n " + self.config.get(
            "runner.days")).read().split(" ")
        pass_per_day = int(self.config.get("runner.pass_per_day"))

        for day in days:
            for pass_index in range(1, pass_per_day + 1):
                logger.info("Day: {} Pass: {} Begin.".format(day, pass_index))

                prepare_data_start_time = time.time()
                dataset = self.wait_and_prepare_dataset(day, pass_index)
                prepare_data_end_time = time.time()
                logger.info(
                    "Prepare Dataset Done, using time {} second.".format(
                        prepare_data_end_time - prepare_data_start_time))

                train_start_time = time.time()
                self.dataset_train_loop(dataset, day, pass_index)
                train_end_time = time.time()
                logger.info(
                    "Train Dataset Done, using time {} second.".format(
                        train_end_time - train_start_time))

                model_dir = "{}/{}/{}".format(save_model_path, day,
                                              pass_index)

                if (fleet.is_first_worker() and save_model_path and
                        is_distributed_env()):
                    # mode=2 saves the delta for this pass, mode=0 the full
                    # base model (merged from two identical guard blocks).
                    fleet.save_inference_model(
                        self.exe, model_dir,
                        [feed.name for feed in self.input_data],
                        self.inference_target_var,
                        mode=2)
                    fleet.save_inference_model(
                        self.exe, model_dir,
                        [feed.name for feed in self.input_data],
                        self.inference_target_var,
                        mode=0)

    def dataset_train_loop(self, cur_dataset, day, pass_index):
        """Run one train_from_dataset pass, then free the dataset memory."""
        logger.info("Day: {} Pass: {}, Running Dataset Begin.".format(
            day, pass_index))
        fetch_info = [
            "Day: {} Pass: {} Var {}".format(day, pass_index, var_name)
            for var_name in self.metrics
        ]
        fetch_vars = [var for _, var in self.metrics.items()]
        # BUGFIX: self.config, not the module-level `config` global.
        print_step = int(self.config.get("runner.print_interval"))
        try:
            self.exe.train_from_dataset(
                program=paddle.static.default_main_program(),
                dataset=cur_dataset,
                fetch_list=fetch_vars,
                fetch_info=fetch_info,
                print_period=print_step,
                debug=self.config.get("runner.dataset_debug"))
        finally:
            # Always release the in-memory dataset, even if training raised,
            # so a failed pass does not leak the loaded data.
            cur_dataset.release_memory()

    def record_result(self):
        """Log the accumulated training statistics.

        BUGFIX: run() called this method but it was never defined, so every
        worker run ended with an AttributeError.
        """
        logger.info("train_result_dict: {}".format(self.train_result_dict))
180+
181+
if __name__ == "__main__":
    # Static-graph mode is required by the parameter-server trainer.
    paddle.enable_static()
    config = parse_args()
    # os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
    benchmark_main = Main(config)
    benchmark_main.run()

0 commit comments

Comments
 (0)