-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathrun_lunarlander_continuous.py
75 lines (62 loc) · 1.94 KB
/
run_lunarlander_continuous.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# -*- coding: utf-8 -*-
"""Train or test baselines on LunarLanderContinuous-v2.
- Author: Curt Park
- Contact: [email protected]
"""
import argparse
import importlib
import gym
import algorithms.common.helper_functions as common_utils
# command-line configuration
def _build_parser():
    """Assemble the command-line parser for this runner script.

    Options are registered in a fixed order so the generated ``--help``
    listing keeps the same layout.
    """
    cli = argparse.ArgumentParser(description="Pytorch RL baselines")

    # reproducibility / algorithm selection
    cli.add_argument(
        "--seed",
        type=int,
        default=777,
        help="random seed for reproducibility",
    )
    cli.add_argument(
        "--algo",
        type=str,
        default="sac",
        help="choose an algorithm",
    )
    cli.add_argument(
        "--load-from",
        type=str,
        default=None,
        help="load the saved model and optimizer at the beginning",
    )

    # episode budget
    cli.add_argument(
        "--episode-num",
        type=int,
        default=1500,
        help="total episode num",
    )
    cli.add_argument(
        "--max-episode-steps",
        type=int,
        default=300,
        help="max episode step",
    )

    # rendering: the flag stores False into ``render``
    cli.add_argument(
        "--off-render",
        dest="render",
        action="store_false",
        help="turn off rendering",
    )
    cli.add_argument(
        "--render-after",
        type=int,
        default=0,
        help="start rendering after the input number of episode",
    )

    # checkpointing, logging and run mode
    cli.add_argument(
        "--save-period",
        type=int,
        default=100,
        help="save model period",
    )
    cli.add_argument(
        "--log",
        action="store_true",
        help="turn on logging",
    )
    cli.add_argument(
        "--test",
        action="store_true",
        help="test mode (no training)",
    )
    cli.add_argument(
        "--demo-path",
        type=str,
        default="data/lunarlander_continuous_demo.pkl",
        help="demonstration path",
    )

    # ``render`` stays True unless --off-render is given
    cli.set_defaults(render=True)
    return cli


# Parsed once at import time from the process command line; main() reads
# the resulting module-level ``args``.
parser = _build_parser()
args = parser.parse_args()
def main():
    """Create the environment and agent, then run testing or training.

    Reads the module-level ``args``: ``args.seed`` is handed to the seeding
    helper, ``args.algo`` names the config module to import, and
    ``args.test`` selects ``agent.test()`` over ``agent.train()``.
    """
    # build the environment
    env = gym.make("LunarLanderContinuous-v2")

    # seed through the shared helper, which also receives the env
    common_utils.set_random_seed(args.seed, env)

    # import the config module that matches the chosen algorithm and let
    # its ``get`` build the agent from the env and the parsed arguments
    config_package = "config.agent.lunarlander_continuous_v2."
    config_module = importlib.import_module(config_package + args.algo)
    agent = config_module.get(env, args)

    # pick the entry point from the run mode, then invoke it
    run = agent.test if args.test else agent.train
    run()


if __name__ == "__main__":
    main()