Skip to content

Commit 2d21055

Browse files
committed
Add distributed training.
1 parent 8aaccbc commit 2d21055

File tree

7 files changed

+159
-0
lines changed

7 files changed

+159
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ char_vocab*.txt
77
glove*.txt
88
glove*.txt.filtered
99
*.v*_*_conll
10+
*.hdf5

continuous_evaluate.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#!/usr/bin/env python
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
import os
7+
import re
8+
import time
9+
import shutil
10+
11+
import tensorflow as tf
12+
import coref_model as cm
13+
import util
14+
15+
def copy_checkpoint(source, target):
    """Duplicate a TensorFlow V2 checkpoint from `source` to `target`.

    A V2 checkpoint on disk is a pair of files sharing a common prefix:
    an `.index` file and a single data shard. Both components are copied
    so that `target` forms a complete, restorable checkpoint prefix.
    """
    checkpoint_suffixes = (".index", ".data-00000-of-00001")
    for suffix in checkpoint_suffixes:
        shutil.copyfile("{}{}".format(source, suffix),
                        "{}{}".format(target, suffix))
18+
19+
if __name__ == "__main__":
    # Continuously polls the training log directory for fresh checkpoints,
    # evaluates each one exactly once, and keeps a copy of the best model
    # (by F1) at model.max.ckpt.
    config = util.initialize_from_env()
    model = cm.CorefModel(config)

    saver = tf.train.Saver()
    log_dir = config["log_dir"]
    writer = tf.summary.FileWriter(log_dir, flush_secs=20)
    evaluated_checkpoints = set()
    max_f1 = 0
    # Raw string: in a plain literal "\Z" is an invalid escape sequence
    # (DeprecationWarning) that only works by accident. \Z anchors the
    # match at end-of-string, so group(1) is the trailing step number.
    checkpoint_pattern = re.compile(r".*model.ckpt-([0-9]*)\Z")

    with tf.Session() as session:
        while True:
            ckpt = tf.train.get_checkpoint_state(log_dir)
            if ckpt and ckpt.model_checkpoint_path and ckpt.model_checkpoint_path not in evaluated_checkpoints:
                print("Evaluating {}".format(ckpt.model_checkpoint_path))

                # Move it to a temporary location to avoid being deleted by the training supervisor.
                tmp_checkpoint_path = os.path.join(log_dir, "model.tmp.ckpt")
                copy_checkpoint(ckpt.model_checkpoint_path, tmp_checkpoint_path)

                global_step = int(checkpoint_pattern.match(ckpt.model_checkpoint_path).group(1))
                # Restore from the stable copy: the original checkpoint may be
                # garbage-collected by the training supervisor at any moment,
                # which is exactly why it was copied above.
                saver.restore(session, tmp_checkpoint_path)

                eval_summary, f1 = model.evaluate(session)

                if f1 > max_f1:
                    max_f1 = f1
                    # Preserve the best-so-far weights under a fixed name.
                    copy_checkpoint(tmp_checkpoint_path, os.path.join(log_dir, "model.max.ckpt"))

                print("Current max F1: {:.2f}".format(max_f1))

                writer.add_summary(eval_summary, global_step)
                print("Evaluation written to {} at step {}".format(log_dir, global_step))

                evaluated_checkpoints.add(ckpt.model_checkpoint_path)
                sleep_time = 60
            else:
                sleep_time = 10
            print("Waiting for {} seconds before looking for next checkpoint.".format(sleep_time))
            time.sleep(sleep_time)

evaluate.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/usr/bin/env python
12
from __future__ import absolute_import
23
from __future__ import division
34
from __future__ import print_function

experiments.conf

+10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@ glove_300d_2w {
1212
size = 300
1313
}
1414

15+
# Distributed training configurations.
16+
two_local_gpus {
17+
addresses {
18+
ps = [localhost:2222]
19+
worker = [localhost:2223, localhost:2224]
20+
}
21+
gpus = [0, 1]
22+
}
23+
1524
# Main configuration.
1625
best {
1726
# Computation limits.
@@ -59,6 +68,7 @@ best {
5968
eval_frequency = 5000
6069
report_frequency = 100
6170
log_root = logs
71+
cluster = ${two_local_gpus}
6272
}
6373

6474
# For evaluation. Do not use for training (i.e. only for predict.py, evaluate.py, and demo.py). Rename `best` directory to `final`.

ps.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/usr/bin/env python
2+
3+
import os
4+
5+
import tensorflow as tf
6+
import util
7+
8+
if __name__ == "__main__":
    # Parameter-server entry point for distributed training: it only hosts
    # shared variables and serves them to the workers.
    config = util.initialize_from_env()
    cluster_config = config["cluster"]
    # The parameter server does no GPU computation; hide all GPUs from it.
    # (Removed an unused `report_frequency` config lookup from the original.)
    util.set_gpus()
    cluster = tf.train.ClusterSpec(cluster_config["addresses"])
    server = tf.train.Server(cluster, job_name="ps", task_index=0)
    # Block forever, serving until the process is killed.
    server.join()

train.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
#!/usr/bin/env python
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
25

36
import os
47
import time

worker.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#!/usr/bin/env python
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
import os
7+
import sys
8+
import time
9+
10+
import tensorflow as tf
11+
import coref_model as cm
12+
import util
13+
14+
if __name__ == "__main__":
    # Worker entry point for distributed training. The worker's task index
    # comes from the TASK environment variable; the cluster layout and GPU
    # assignment come from the experiment config.
    config = util.initialize_from_env()
    task_index = int(os.environ["TASK"])

    report_frequency = config["report_frequency"]
    cluster_config = config["cluster"]

    # Pin this worker to its configured GPU before any TF graph is built.
    util.set_gpus(cluster_config["gpus"][task_index])

    cluster = tf.train.ClusterSpec(cluster_config["addresses"])
    server = tf.train.Server(cluster,
                             job_name="worker",
                             task_index=task_index)

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % task_index, cluster=cluster)):
        model = cm.CorefModel(config)
        saver = tf.train.Saver()
        init_op = tf.global_variables_initializer()

    log_dir = config["log_dir"]
    # Per-worker summary subdirectory so workers don't clobber each other.
    writer = tf.summary.FileWriter(os.path.join(log_dir, "w{}".format(task_index)), flush_secs=20)

    is_chief = (task_index == 0)

    # Create a "supervisor", which oversees the training process.
    sv = tf.train.Supervisor(is_chief=is_chief,
                             logdir=log_dir,
                             init_op=init_op,
                             saver=saver,
                             global_step=model.global_step,
                             save_model_secs=120)

    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target) as session:
        model.start_enqueue_thread(session)
        accumulated_loss = 0.0
        initial_time = time.time()
        while not sv.should_stop():
            tf_loss, tf_global_step, _ = session.run([model.loss, model.global_step, model.train_op])
            accumulated_loss += tf_loss

            if tf_global_step % report_frequency == 0:
                total_time = time.time() - initial_time
                steps_per_second = tf_global_step / total_time

                average_loss = accumulated_loss / report_frequency
                # Fix: report the loss averaged over the reporting window.
                # The original computed `average_loss` but printed the
                # single most recent batch loss (`tf_loss`) instead.
                print("[{}] loss={:.2f}, steps/s={:.2f}".format(tf_global_step, average_loss, steps_per_second))
                accumulated_loss = 0.0
                # Fix: tag the summary with the global step (as
                # continuous_evaluate.py does) so TensorBoard can plot it
                # as a time series instead of collapsing the points.
                writer.add_summary(util.make_summary({
                    "Train Loss": average_loss,
                    "Steps per second": steps_per_second
                }), tf_global_step)

    # Ask for all the services to stop.
    sv.stop()

0 commit comments

Comments
 (0)