Skip to content

Commit b6e88df

Browse files
committed
Fix the buffer bug and add td3 for elegantrl
1 parent 739b131 commit b6e88df

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

rofunc/examples/learning/CURICabinet_elegantrl.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,25 @@ def eval(custom_args, ckpt_path=None):
2828
# TODO: add support for eval mode
2929
beauty_print("Start evaluating")
3030

31-
env, args = setup(custom_args, eval_mode=True)
31+
env, agent = setup(custom_args, eval_mode=True)
3232

3333
# load checkpoint
3434
if ckpt_path is None:
3535
ckpt_path = model_zoo(name="CURICabinetPPO_right_arm.pt")
36-
agent.load(ckpt_path)
36+
agent.save_or_load_agent(cwd=ckpt_path, if_save=False)
3737

3838
# evaluate the agent
39-
trainer.eval()
39+
state = env.reset()
40+
episode_reward = 0
41+
for i in range(2 ** 10):
42+
action = agent.act.get_action(state).detach()
43+
next_state, reward, done, _ = env.step(action)
44+
episode_reward += reward.mean()
45+
# if done:
46+
# print(f'Step {i:>6}, Episode return {episode_reward:8.3f}')
47+
# break
48+
# else:
49+
state = next_state
4050

4151

4252
if __name__ == '__main__':
@@ -47,15 +57,12 @@ def eval(custom_args, ckpt_path=None):
4757
parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
4858
parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))
4959
parser.add_argument("--graphics_device_id", type=int, default=gpu_id)
50-
parser.add_argument("--headless", type=str, default="False")
60+
parser.add_argument("--headless", type=str, default="True")
5161
parser.add_argument("--test", action="store_true", help="turn to test mode while adding this argument")
5262
custom_args = parser.parse_args()
5363

5464
if not custom_args.test:
5565
train(custom_args)
5666
else:
57-
# TODO: add support for eval mode
58-
folder = 'CURICabinetSAC_22-11-27_18-38-53-296354'
59-
ckpt_path = "/home/ubuntu/Github/Knowledge-Universe/Robotics/Roadmap-for-robot-science/rofunc/examples/learning/runs/{}/checkpoints/best_agent.pt".format(
60-
folder)
67+
ckpt_path = "/home/ubuntu/Github/Knowledge-Universe/Robotics/Roadmap-for-robot-science/rofunc/examples/learning/result/CURICabinet_SAC_42/actor_53608448_00007.742.pth"
6168
eval(custom_args, ckpt_path=ckpt_path)

rofunc/lfd/rl/utils/elegantrl_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
from elegantrl.train.config import Arguments
1313
from elegantrl.agents.AgentPPO import AgentPPO
1414
from elegantrl.agents.AgentSAC import AgentSAC
15+
from elegantrl.agents.AgentTD3 import AgentTD3
16+
from elegantrl.train.run import init_agent
1517

1618
from rofunc.config.utils import get_config, omegaconf_to_dict
1719
from rofunc.lfd.rl.tasks import task_map
20+
from rofunc.utils.logger.beauty_logger import beauty_print
1821

1922

2023
class ElegantRLIsaacGymEnvWrapper:
@@ -96,6 +99,7 @@ def step(
9699
def setup(custom_args, eval_mode=False):
97100
# get config
98101
sys.argv.append("task={}".format(custom_args.task))
102+
beauty_print("Agent: {}{}ElegantRL".format(custom_args.task, custom_args.agent.upper()), 2)
99103
sys.argv.append("sim_device={}".format(custom_args.sim_device))
100104
sys.argv.append("rl_device={}".format(custom_args.rl_device))
101105
sys.argv.append("graphics_device_id={}".format(custom_args.graphics_device_id))
@@ -106,6 +110,7 @@ def setup(custom_args, eval_mode=False):
106110

107111
if eval_mode:
108112
task_cfg_dict['env']['numEnvs'] = 16
113+
cfg.headless = False
109114

110115
env = task_map[custom_args.task](cfg=task_cfg_dict,
111116
rl_device=cfg.rl_device,
@@ -121,6 +126,8 @@ def setup(custom_args, eval_mode=False):
121126
agent_class = AgentPPO
122127
elif custom_args.agent.lower() == "sac":
123128
agent_class = AgentSAC
129+
elif custom_args.agent.lower() == "td3":
130+
agent_class = AgentTD3
124131
else:
125132
raise ValueError("Agent not supported")
126133

@@ -143,4 +150,8 @@ def setup(custom_args, eval_mode=False):
143150
args.learner_gpus = cfg.graphics_device_id
144151
args.random_seed = 42
145152

153+
if eval_mode:
154+
agent = init_agent(args, args.learner_gpus, env)
155+
return env, agent
156+
146157
return env, args

0 commit comments

Comments
 (0)