Commit e10d986

Merge pull request #4642 from buckman-google/master
Addition of STEVE
2 parents ee0e9d1 + f789dcf commit e10d986

166 files changed: +2804 additions, 0 deletions

CODEOWNERS

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@
 /research/seq2species/ @apbusia @depristo
 /research/skip_thoughts/ @cshallue
 /research/slim/ @sguada @nathansilberman
+/research/steve/ @buckman-google
 /research/street/ @theraysmith
 /research/swivel/ @waterson
 /research/syntaxnet/ @calberti @andorardo @bogatyy @markomernick

research/steve/README.md

Lines changed: 90 additions & 0 deletions
# Stochastic Ensemble Value Expansion

*A hybrid model-based/model-free reinforcement learning algorithm for sample-efficient continuous control.*

This is the code repository accompanying the paper "Sample-Efficient Reinforcement Learning with
Stochastic Ensemble Value Expansion" by Buckman et al. (2018).

#### Abstract:
Merging model-free and model-based approaches in reinforcement learning has the potential to achieve
the high performance of model-free algorithms with low sample complexity. This is difficult because
an imperfect dynamics model can degrade the performance of the learning algorithm, and in sufficiently
complex environments, the dynamics model will always be imperfect. As a result, a key challenge is to
combine model-based approaches with model-free learning in such a way that errors in the model do not
degrade performance. We propose *stochastic ensemble value expansion* (STEVE), a novel model-based
technique that addresses this issue. By dynamically interpolating between model rollouts of various horizon
lengths for each individual example, STEVE ensures that the model is only utilized when doing so does not
introduce significant errors. Our approach outperforms model-free baselines on challenging continuous
control benchmarks with an order-of-magnitude increase in sample efficiency, and in contrast to previous
model-based approaches, performance does not degrade as the environment gets more complex.
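The interpolation described in the abstract amounts to an inverse-variance weighting of the candidate targets produced by the ensemble at each rollout horizon. The repository's actual estimator lives in `valuerl.py`; the numpy sketch below is only a schematic of that reweighting idea, with illustrative shapes and names.

```python
# Schematic of STEVE-style target reweighting (illustrative only; the actual
# implementation is in valuerl.py). Each rollout horizon yields several target
# estimates from the ensemble; horizons whose estimates disagree get less weight.
import numpy as np

def steve_target(candidate_targets, eps=1e-8):
    """candidate_targets: [num_horizons, ensemble_size] array of value targets."""
    means = candidate_targets.mean(axis=1)           # one candidate per horizon
    variances = candidate_targets.var(axis=1) + eps  # ensemble disagreement
    weights = (1.0 / variances) / np.sum(1.0 / variances)
    return np.sum(weights * means)

# A horizon with large model error (high disagreement) barely moves the target.
targets = np.array([[1.00, 1.05, 0.95],    # short rollout: consistent estimates
                    [4.00, -2.00, 6.00]])  # long rollout: model errors accumulate
print(steve_target(targets))  # close to 1.0
```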
## Installation
This code is compatible with Ubuntu 16.04 and Python 2.7. There are several prerequisites:
* Numpy, Scipy, and Portalocker: `pip install numpy scipy portalocker`
* TensorFlow 1.6 or above. Instructions can be found on the official TensorFlow page:
  [https://www.tensorflow.org/install/install_linux](https://www.tensorflow.org/install/install_linux).
  We suggest installing the GPU version of TensorFlow to speed up training.
* OpenAI Gym version 0.9.4. Instructions can be found in the OpenAI Gym repository:
  [https://github.com/openai/gym#installation](https://github.com/openai/gym#installation).
  Note that you need to replace `pip install gym[all]` with `pip install gym[all]==0.9.4`, which
  will ensure that you get the correct version of Gym. (The current version of Gym has deprecated
  the -v1 MuJoCo environments, which are the environments studied in this paper.)
* MuJoCo version 1.31, which can be downloaded here: [https://www.roboti.us/download/mjpro131_linux.zip](https://www.roboti.us/download/mjpro131_linux.zip).
  Simply run:
  ```
  cd ~; mkdir -p .mujoco; cd .mujoco/; wget https://www.roboti.us/download/mjpro131_linux.zip; unzip mjpro131_linux.zip
  ```
  You also need to get a license, and put the license key in ~/.mujoco/ as well.
* Optionally, Roboschool version 1.1. This is needed only to replicate the Roboschool experiments.
  Instructions can be found in the OpenAI Roboschool repository:
  [https://github.com/openai/roboschool#installation](https://github.com/openai/roboschool#installation).
* Optionally, MoviePy to render trained agents. Instructions are on the MoviePy homepage:
  [https://zulko.github.io/moviepy/install.html](https://zulko.github.io/moviepy/install.html).
## Running Experiments
To run an experiment, run `master.py` and pass in a config file and GPU ID. For example:
```
python master.py config/experiments/speedruns/humanoid/speedy_steve0.json 0
```
The `config/experiments/` directory contains configuration files for all of the experiments run in the paper.

The GPU ID specifies the GPU that should be used to learn the policy. For model-based approaches, the
next GPU (i.e. GPU_ID+1) is used to learn the worldmodel in parallel.

To resume an experiment that was interrupted, use the same config file and pass the `--resume` flag:
```
python master.py config/experiments/speedruns/humanoid/speedy_steve0.json 0 --resume
```
## Output
For each experiment, two folders are created in the output directory: `<ENVIRONMENT>/<EXPERIMENT>/log`
and `<ENVIRONMENT>/<EXPERIMENT>/checkpoints`. The log directory contains the following:

* `hps.json` contains the accumulated hyperparameters of the config file used to generate these results.
* `valuerl.log` and `worldmodel.log` contain the log output of the learners. `worldmodel.log` will not
  exist if you are not learning a worldmodel.
* `<EXPERIMENT>.greedy.csv` records all of the scores of our evaluators. The four columns contain time (hours),
  epochs, frames, and score.

The checkpoints directory contains the most recent versions of the policy and worldmodel, as well as checkpoints
of the policy, worldmodel, and their respective replay buffers at various points throughout training.
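Since the evaluator CSV has no header row (each evaluator appends `hours,epoch,frames,score` lines, as seen in `agent.py` below), a minimal numpy sketch for inspecting it could look like the following; the exact path is hypothetical and depends on your output root, environment, and experiment name.

```python
# Minimal sketch for reading an evaluator CSV; the path below is hypothetical.
import numpy as np

# Columns, in order: time (hours), epochs, frames, score. There is no header row.
data = np.loadtxt("output/humanoid/speedy_steve0/log/speedy_steve0.greedy.csv",
                  delimiter=",", ndmin=2)
hours, epochs, frames, scores = data.T

print("latest score %.2f after %.2f hours (%d frames)"
      % (scores[-1], hours[-1], frames[-1]))
print("best score   %.2f" % scores.max())
```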
## Code Organization
`master.py` launches four types of processes: a ValueRlLearner to learn the policy, a WorldmodelLearner
to learn the dynamics model, several Interactors to gather data from the environment to train on, and
a few Evaluators to run the greedy policy in the environment and record the score.

`learner.py` contains a general framework for models which learn from a replay buffer. This is where
most of the code for the overall training loop is located. `valuerl_learner.py` and `worldmodel_learner.py`
contain a small amount of model-specific training loop code.

`valuerl.py` implements the core model for all value-function-based policy learning techniques studied
in the paper, including DDPG, MVE, STEVE, etc. Similarly, `worldmodel.py` contains the core model for
our dynamics model and reward function.

`replay.py` contains the code for the replay buffer. `nn.py`, `envwrap.py`, `config.py`, and `util.py`
each contain various helper functions.

`toy_demo.py` is a self-contained demo, written in numpy, that was used to generate the results for the
toy examples in the first segment of the paper.

`visualizer.py` is a utility script for loading trained policies and inspecting them. In addition to a
config file and a GPU, it takes the filename of the model to load as a mandatory third argument.

## Contact
Please contact GitHub user buckman-google ([email protected]) with any questions.

research/steve/agent.py

Lines changed: 143 additions & 0 deletions
from __future__ import print_function
from builtins import zip
from builtins import range
from builtins import object
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf
import time, os, traceback, multiprocessing, portalocker

import envwrap
import valuerl
import util
from config import config


def run_env(pipe):
    """Environment worker: receives actions over a pipe and sends back transitions."""
    env = envwrap.get_env(config["env"]["name"])
    reset = True
    while True:
        if reset is True: pipe.send(env.reset())
        action = pipe.recv()
        obs, reward, done, reset = env.step(action)
        pipe.send((obs, reward, done, reset))

class AgentManager(object):
    """
    Interact with the environment according to the learned policy.
    """
    def __init__(self, proc_num, evaluation, policy_lock, batch_size, config):
        self.evaluation = evaluation
        self.policy_lock = policy_lock
        self.batch_size = batch_size
        self.config = config

        self.log_path = util.create_directory("%s/%s/%s/%s" % (config["output_root"], config["env"]["name"], config["name"], config["log_path"])) + "/%s" % config["name"]
        self.load_path = util.create_directory("%s/%s/%s/%s" % (config["output_root"], config["env"]["name"], config["name"], config["save_model_path"]))

        ## placeholders for intermediate states (basis for rollout)
        self.obs_loader = tf.placeholder(tf.float32, [self.batch_size, np.prod(self.config["env"]["obs_dims"])])

        ## build model
        self.valuerl = valuerl.ValueRL(self.config["name"], self.config["env"], self.config["policy_config"])
        self.policy_actions = self.valuerl.build_evalution_graph(self.obs_loader, mode="exploit" if self.evaluation else "explore")

        # interactors: one environment worker process per element of the batch
        self.agent_pipes, self.agent_child_pipes = list(zip(*[multiprocessing.Pipe() for _ in range(self.batch_size)]))
        self.agents = [multiprocessing.Process(target=run_env, args=(self.agent_child_pipes[i],)) for i in range(self.batch_size)]
        for agent in self.agents: agent.start()
        self.obs = [pipe.recv() for pipe in self.agent_pipes]
        self.total_rewards = [0. for _ in self.agent_pipes]
        self.loaded_policy = False

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        self.rollout_i = 0
        self.proc_num = proc_num
        self.epoch = -1
        self.frame_total = 0
        self.hours = 0.

        self.first = True

    def get_action(self, obs):
        # Use the learned policy once a checkpoint has been loaded; act randomly before that.
        if self.loaded_policy:
            all_actions = self.sess.run(self.policy_actions, feed_dict={self.obs_loader: obs})
            all_actions = np.clip(all_actions, -1., 1.)
            return all_actions[:self.batch_size]
        else:
            return [self.get_random_action() for _ in range(obs.shape[0])]

    def get_random_action(self, *args, **kwargs):
        return np.random.random(self.config["env"]["action_dim"]) * 2 - 1

    def step(self):
        # Send one action to every environment worker and collect the resulting transitions.
        actions = self.get_action(np.stack(self.obs))
        self.first = False
        [pipe.send(action) for pipe, action in zip(self.agent_pipes, actions)]
        next_obs, rewards, dones, resets = list(zip(*[pipe.recv() for pipe in self.agent_pipes]))

        frames = list(zip(self.obs, next_obs, actions, rewards, dones))

        self.obs = [o if resets[i] is False else self.agent_pipes[i].recv() for i, o in enumerate(next_obs)]

        for i, (t, r, reset) in enumerate(zip(self.total_rewards, rewards, resets)):
            if reset:
                self.total_rewards[i] = 0.
                if self.evaluation and self.loaded_policy:
                    # Evaluators append (hours, epoch, frames, score) to the greedy CSV.
                    with portalocker.Lock(self.log_path+'.greedy.csv', mode="a") as f: f.write("%2f,%d,%d,%2f\n" % (self.hours, self.epoch, self.frame_total, t+r))
            else:
                self.total_rewards[i] = t + r

        if self.evaluation and np.any(resets): self.reload()

        self.rollout_i += 1
        return frames

    def reload(self):
        # Load the most recent policy checkpoint, if one exists.
        if not os.path.exists("%s/%s.params.index" % (self.load_path, self.valuerl.saveid)): return False
        with self.policy_lock:
            self.valuerl.load(self.sess, self.load_path)
            self.epoch, self.frame_total, self.hours = self.sess.run([self.valuerl.epoch_n, self.valuerl.frame_n, self.valuerl.hours])
        self.loaded_policy = True
        self.first = True
        return True

def main(proc_num, evaluation, policy_replay_frame_queue, model_replay_frame_queue, policy_lock, config):
    try:
        np.random.seed((proc_num * int(time.time())) % (2 ** 32 - 1))
        agentmanager = AgentManager(proc_num, evaluation, policy_lock, config["evaluator_config"]["batch_size"] if evaluation else config["agent_config"]["batch_size"], config)
        frame_i = 0
        while True:
            new_frames = agentmanager.step()
            if not evaluation:
                # Interactors feed both replay buffers and periodically reload the policy.
                policy_replay_frame_queue.put(new_frames)
                if model_replay_frame_queue is not None: model_replay_frame_queue.put(new_frames)
                if frame_i % config["agent_config"]["reload_every_n"] == 0: agentmanager.reload()
            frame_i += len(new_frames)

    except Exception as e:
        print('Caught exception in agent process %d' % proc_num)
        traceback.print_exc()
        print()
        try:
            for i in agentmanager.agents: i.join()
        except:
            pass
        raise e
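`master.py` itself is not included in this excerpt. Purely as an illustration of how `agent.main` above is meant to be driven, a launcher could look like the sketch below: Interactors (`evaluation=False`) push frames into the replay queues, while Evaluators (`evaluation=True`) only log greedy scores. The process counts, script structure, and `__main__` guard are assumptions, not the repository's actual launcher (which also starts the learner processes that drain these queues).

```python
# Hypothetical launcher sketch around agent.main(proc_num, evaluation,
# policy_replay_frame_queue, model_replay_frame_queue, policy_lock, config).
import multiprocessing

import agent
from config import config

if __name__ == "__main__":
    # Queues feeding the policy / worldmodel replay buffers, plus a lock that
    # serializes checkpoint loading against the policy learner's saves.
    policy_frames = multiprocessing.Queue()
    model_frames = multiprocessing.Queue()
    policy_lock = multiprocessing.Lock()

    procs = []
    num_interactors, num_evaluators = 2, 1  # illustrative counts
    for i in range(num_interactors):
        procs.append(multiprocessing.Process(
            target=agent.main,
            args=(i, False, policy_frames, model_frames, policy_lock, config)))
    for i in range(num_evaluators):
        procs.append(multiprocessing.Process(
            target=agent.main,
            args=(num_interactors + i, True, policy_frames, model_frames, policy_lock, config)))

    for p in procs:
        p.start()
    for p in procs:
        p.join()
```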

research/steve/config.py

Lines changed: 38 additions & 0 deletions
from __future__ import print_function
from builtins import str
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse, json, util, traceback

# Parse the config path, the root GPU id, and the optional --resume flag once;
# other modules obtain the resulting ConfigDict by importing `config`.
parser = argparse.ArgumentParser()
parser.add_argument("config")
parser.add_argument("root_gpu", type=int)
parser.add_argument("--resume", action="store_true")
args = parser.parse_args()

config_loc = args.config
config = util.ConfigDict(config_loc)

config["name"] = config_loc.split("/")[-1][:-5]  # config filename without ".json"
config["resume"] = args.resume

cstr = str(config)

def log_config():
    # Record the resolved hyperparameters (the hps.json mentioned in the README).
    HPS_PATH = util.create_directory("output/" + config["env"]["name"] + "/" + config["name"] + "/" + config["log_path"]) + "/hps.json"
    print("ROOT GPU: " + str(args.root_gpu) + "\n" + str(cstr))
    with open(HPS_PATH, "w") as f:
        f.write("ROOT GPU: " + str(args.root_gpu) + "\n" + str(cstr))
research/steve/config/algos/ddpg.json

Lines changed: 3 additions & 0 deletions
{
  "inherits": ["config/core/basic.json"]
}
Lines changed: 14 additions & 0 deletions
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "mean_k_return": true
      }
    }
  }
}
Lines changed: 14 additions & 0 deletions
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "tdk_trick": true
      }
    }
  }
}
Lines changed: 14 additions & 0 deletions
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "lambda_return": 0.25
      }
    }
  }
}
Lines changed: 15 additions & 0 deletions
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json",
    "config/core/bayesian.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "steve_reweight": true
      }
    }
  }
}
Lines changed: 16 additions & 0 deletions
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json",
    "config/core/bayesian.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "steve_reweight": true,
        "covariances": true
      }
    }
  }
}
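The configuration files above all follow the same pattern: `inherits` names parent configs to merge, and `updates` overrides individual leaves. The merging is performed by `util.ConfigDict`, which is not part of this excerpt, so the resolver below is only a guess at those semantics, written out to make the pattern concrete.

```python
# Illustrative resolver for the inherits/updates pattern in the JSON configs
# above. util.ConfigDict is not shown in this commit, so treat this as an
# assumption about its semantics rather than the actual implementation.
import json

def deep_update(base, updates):
    """Recursively merge `updates` into `base`, overriding leaf values."""
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base

def load_config(path):
    """Load a JSON config, resolving any parents listed under 'inherits' first."""
    with open(path) as f:
        raw = json.load(f)
    resolved = {}
    for parent in raw.get("inherits", []):
        deep_update(resolved, load_config(parent))
    return deep_update(resolved, raw.get("updates", {}))

# e.g. load_config("config/algos/ddpg.json")
```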
