-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathrun_reacher_v1.py
74 lines (61 loc) · 1.99 KB
/
run_reacher_v1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding: utf-8 -*-
"""Train or test algorithms on Reacher-v1 of Mujoco.
- Author: Kh Kim
- Contact: [email protected]
"""
import argparse
import importlib
import gym
import algorithms.common.helper_functions as common_utils
# configurations: command-line options for training / testing on Reacher-v1.
parser = argparse.ArgumentParser(description="Pytorch RL algorithms")

# (flag, keyword options) for each option, listed in the order they are
# registered (and therefore the order they appear in --help).
_CLI_OPTIONS = (
    ("--seed", dict(type=int, default=777, help="random seed for reproducibility")),
    ("--algo", dict(type=str, default="sac", help="choose an algorithm")),
    (
        "--test",
        dict(dest="test", action="store_true", help="test mode (no training)"),
    ),
    (
        "--load-from",
        dict(type=str, help="load the saved model and optimizer at the beginning"),
    ),
    (
        "--off-render",
        dict(dest="render", action="store_false", help="turn off rendering"),
    ),
    (
        "--render-after",
        dict(
            type=int,
            default=0,
            help="start rendering after the input number of episode",
        ),
    ),
    ("--log", dict(dest="log", action="store_true", help="turn on logging")),
    ("--save-period", dict(type=int, default=200, help="save model period")),
    ("--episode-num", dict(type=int, default=20000, help="total episode num")),
    ("--max-episode-steps", dict(type=int, default=-1, help="max episode step")),
    (
        "--demo-path",
        dict(type=str, default="data/reacher_demo.pkl", help="demonstration path"),
    ),
)
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
del _flag, _options  # keep the loop variables out of the module namespace

# Explicit defaults for the boolean switches and the optional checkpoint path
# (equal to what the actions above would default to on their own).
parser.set_defaults(test=False, load_from=None, render=True, log=False)

# Parsed once at import time; main() reads this module-level namespace.
args = parser.parse_args()
def main():
    """Create the Reacher-v1 environment and agent, then test or train it.

    Uses the module-level ``args`` namespace: ``args.seed`` seeds the run,
    ``args.algo`` selects the agent config module, and ``args.test`` chooses
    between ``agent.test()`` and ``agent.train()``.
    """
    # environment, seeded before the agent is built
    environment = gym.make("Reacher-v1")
    common_utils.set_random_seed(args.seed, environment)

    # the agent factory is the `get` function of config.agent.reacher-v1.<algo>
    config_module = importlib.import_module(
        "config.agent.reacher-v1.{}".format(args.algo)
    )
    agent = config_module.get(environment, args)

    # test mode evaluates only; otherwise run training
    run = agent.test if args.test else agent.train
    run()


if __name__ == "__main__":
    main()