# run_deeplte.py
"""Train."""
import functools
import os
import pathlib
from absl import app, flags, logging
from jaxline import platform
from deeplte.checkpoint import (
restore_state_to_in_memory_checkpointer,
save_state_from_in_memory_checkpointer,
setup_signals,
)
from deeplte.train import Trainer
# Process-wide absl flag registry; flag values are read from it as attributes.
FLAGS = flags.FLAGS

# Optional --data_path flag (defaults to None); write_data_path() copies its
# value onto the dataset config.
flags.DEFINE_string(name="data_path", default=None, help="data path")
def main(experiment_class, argv):
    """Run a jaxline experiment with checkpoint restore/save wired around it.

    Steps, in order:
      1. Create the ``./figure`` and ``./ckpts`` directories.
      2. Copy the ``--data_path`` flag onto the dataset config.
      3. If ``config.restore_dir`` is set, restore state into the in-memory
         checkpointer.
      4. Build the save callback (a no-op when ``config.one_off_evaluate`` is
         set) and register it with ``setup_signals``.
      5. In train modes, make sure ``config.checkpoint_dir`` exists and send
         the absl log file there.
      6. Run ``platform.main`` and always invoke the save callback afterwards.

    Args:
        experiment_class: Experiment class passed to ``platform.main`` and to
            the checkpoint saver.
        argv: Command-line arguments, forwarded unchanged to ``platform.main``.
    """
    # make figures/ and ckpts/ directories
    mkdir(file_dirs=["./figure", "./ckpts"])
    write_data_path(FLAGS.config.experiment_kwargs.config.dataset)

    # Maybe restore a model.
    restore_dir = FLAGS.config.restore_dir
    if restore_dir:
        restore_state_to_in_memory_checkpointer(restore_dir)

    # Maybe save a model.
    checkpoint_dir = FLAGS.config.checkpoint_dir
    save_dir = os.path.join(checkpoint_dir, "models")
    if FLAGS.config.one_off_evaluate:

        def save_model_fn():
            """No need to save a checkpoint in this case."""
            return None

    else:
        save_model_fn = functools.partial(
            save_state_from_in_memory_checkpointer, save_dir, experiment_class
        )
    setup_signals(save_model_fn)  # Save on Ctrl+C (continue) or Ctrl+\ (exit).

    if FLAGS.jaxline_mode.startswith("train"):
        # parents=True creates missing parent directories; exist_ok=True makes
        # this safe when another process creates the directory concurrently.
        pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
        logging.get_absl_handler().use_absl_log_file("train", checkpoint_dir)

    try:
        platform.main(experiment_class, argv)
    finally:
        save_model_fn()  # Save at the end of training or in case of exception.
def write_data_path(config):
    """Store the value of the ``--data_path`` flag on ``config.data_path``.

    Args:
        config: Config object to mutate in place; its ``data_path`` attribute
            is overwritten with the flag value (``None`` when the flag was not
            given, since that is the flag's default).
    """
    data_path = FLAGS.data_path
    config.data_path = data_path
def mkdir(file_dirs):
    """Create every directory in ``file_dirs``, including missing parents.

    Directories that already exist are left untouched, so the call is safe to
    repeat.

    Args:
        file_dirs: Iterable of directory paths to create.

    Raises:
        FileExistsError: If a path already exists but is not a directory.
    """
    for file_dir in file_dirs:
        # exist_ok=True replaces the separate exists() check, which could race
        # with another process creating the same directory in between.
        os.makedirs(file_dir, exist_ok=True)
if __name__ == "__main__":
    # The "config" flag has to be supplied on the command line.
    flags.mark_flag_as_required("config")
    # Bind Trainer as experiment_class before handing the entry point to absl.
    entry_point = functools.partial(main, Trainer)
    app.run(entry_point)