-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathrun_open_manipulator_reacher_v0.py
executable file
·77 lines (63 loc) · 2.2 KB
/
run_open_manipulator_reacher_v0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""Train or test algorithms on OpenManipulator Reacher-v0 on Gazebo.
- Author: Kh Kim
- Contact: [email protected]
"""
import argparse
import importlib
from config.environment.open_manipulator import config as env_cfg
import algorithms.common.helper_functions as common_utils
from envs.open_manipulator.open_manipulator_reacher_env import OpenManipulatorReacherEnv
# configurations
def _build_arg_parser():
    """Create the command-line parser for this training/testing script.

    Returns:
        argparse.ArgumentParser: parser whose namespace exposes ``seed``,
        ``algo``, ``test``, ``load_from``, ``render``, ``render_after``,
        ``log``, ``save_period``, ``episode_num``, ``max_episode_steps``
        and ``demo_path``.
    """
    cli = argparse.ArgumentParser(description="Pytorch RL algorithms")

    # Reproducibility and algorithm selection.
    cli.add_argument(
        "--seed", type=int, default=777, help="random seed for reproducibility"
    )
    cli.add_argument("--algo", type=str, default="td3", help="choose an algorithm")

    # Run mode and checkpoint restoration.
    cli.add_argument(
        "--test",
        dest="test",
        action="store_true",
        default=False,
        help="test mode (no training)",
    )
    cli.add_argument(
        "--load-from",
        type=str,
        default=None,
        help="load the saved model and optimizer at the beginning",
    )

    # Rendering is on unless --off-render is given.
    cli.add_argument(
        "--off-render",
        dest="render",
        action="store_false",
        default=True,
        help="turn off rendering",
    )
    cli.add_argument(
        "--render-after",
        type=int,
        default=0,
        help="start rendering after the input number of episode",
    )

    # Logging, checkpointing and episode budget.
    cli.add_argument(
        "--log", dest="log", action="store_true", default=False, help="turn on logging"
    )
    cli.add_argument("--save-period", type=int, default=200, help="save model period")
    cli.add_argument("--episode-num", type=int, default=20000, help="total episode num")
    # NOTE(review): -1 is presumably a sentinel for "no explicit limit";
    # its meaning is decided by the agent config, not here.
    cli.add_argument(
        "--max-episode-steps", type=int, default=-1, help="max episode step"
    )

    # Demonstration data for algorithms that learn from demos.
    cli.add_argument(
        "--demo-path", type=str, default="data/reacher_demo.pkl", help="demonstration path"
    )
    return cli


parser = _build_arg_parser()
# Parsed at import time; main() reads this module-level namespace.
args = parser.parse_args()
def main():
    """Build the Reacher-v0 environment and agent, then train or test it.

    Reads the module-level ``args`` namespace parsed from the command line.
    The agent is produced by ``get(env, args)`` from the module
    ``config.agent.open_manipulator_reacher_v0.<args.algo>``, which is
    imported dynamically. Afterwards exactly one of the agent's ``test()``
    or ``train()`` methods is invoked, depending on ``args.test``.
    """
    # Environment instance built from the shared open_manipulator env config.
    env = OpenManipulatorReacherEnv(env_cfg)

    # Seeding happens before the agent is created (presumably so agent
    # initialization is reproducible as well).
    common_utils.set_random_seed(args.seed, env)

    # Resolve the algorithm-specific agent config module by name, e.g. "td3".
    config_package = "config.agent.open_manipulator_reacher_v0."
    algo_config = importlib.import_module(config_package + args.algo)
    agent = algo_config.get(env, args)

    # --test runs evaluation only; otherwise the agent is trained.
    run = agent.test if args.test else agent.train
    run()


if __name__ == "__main__":
    main()