# rollout.py
from collections import deque

import numpy as np
import pickle

from mujoco_py import MujocoException

from util import convert_episode_to_batch_major, store_args
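
# NOTE (assumption): `store_args` is taken to be a decorator that stores every
# constructor argument as an attribute of the same name (hence `self.T` etc.
# below), and `convert_episode_to_batch_major` is taken to transpose each
# episode array from time-major (T, rollout_batch_size, dim) to batch-major
# (rollout_batch_size, T, dim) layout. Both live in `util`, outside this file.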


class RolloutWorker:

    @store_args
    def __init__(self, make_env, policy, dims, logger, T, rollout_batch_size=1,
                 exploit=False, use_target_net=False, compute_Q=False, noise_eps=0,
                 random_eps=0, history_len=100, render=False, **kwargs):
        """Rollout worker that generates experience by interacting with one or many environments.

        Args:
            make_env (function): a factory function that creates a new instance of the environment
                when called
            policy (object): the policy that is used to act
            dims (dict of ints): the dimensions for observations (o), goals (g), and actions (u)
            logger (object): the logger that is used by the rollout worker
            T (int): the time horizon, i.e. the number of timesteps per rollout
            rollout_batch_size (int): the number of parallel rollouts that should be used
            exploit (boolean): whether or not to exploit, i.e. to act optimally according to the
                current policy without any exploration
            use_target_net (boolean): whether or not to use the target net for rollouts
            compute_Q (boolean): whether or not to compute the Q values alongside the actions
            noise_eps (float): scale of the additive Gaussian noise
            random_eps (float): probability of selecting a completely random action
            history_len (int): length of history for statistics smoothing
            render (boolean): whether or not to render the rollouts
        """
        self.envs = [make_env() for _ in range(rollout_batch_size)]
        assert self.T > 0

        self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')]

        self.success_history = deque(maxlen=history_len)
        self.Q_history = deque(maxlen=history_len)

        self.n_episodes = 0
        self.g = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)  # goals
        self.initial_o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32)  # observations
        self.initial_ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)  # achieved goals
        self.reset_all_rollouts()
        self.clear_history()

    def reset_rollout(self, i):
        """Resets the `i`-th rollout environment, re-samples a new goal, and updates the `initial_o`
        and `g` arrays accordingly.
        """
        obs = self.envs[i].reset()
        self.initial_o[i] = obs['observation']
        self.initial_ag[i] = obs['achieved_goal']
        self.g[i] = obs['desired_goal']

    def reset_all_rollouts(self):
        """Resets all `rollout_batch_size` rollout environments.
        """
        for i in range(self.rollout_batch_size):
            self.reset_rollout(i)

    def generate_rollouts(self):
        """Performs `rollout_batch_size` rollouts in parallel for time horizon `T` with the current
        policy acting on it accordingly.
        """
        self.reset_all_rollouts()

        # compute observations
        o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32)  # observations
        ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)  # achieved goals
        o[:] = self.initial_o
        ag[:] = self.initial_ag

        # generate episodes
        obs, achieved_goals, acts, goals, successes = [], [], [], [], []
        info_values = [np.empty((self.T, self.rollout_batch_size, self.dims['info_' + key]), np.float32)
                       for key in self.info_keys]
        Qs = []
        for t in range(self.T):
            policy_output = self.policy.get_actions(
                o, ag, self.g,
                compute_Q=self.compute_Q,
                noise_eps=self.noise_eps if not self.exploit else 0.,
                random_eps=self.random_eps if not self.exploit else 0.,
                use_target_net=self.use_target_net)

            if self.compute_Q:
                u, Q = policy_output
                Qs.append(Q)
            else:
                u = policy_output

            if u.ndim == 1:
                # The non-batched case should still have a reasonable shape.
                u = u.reshape(1, -1)
            o_new = np.empty((self.rollout_batch_size, self.dims['o']))
            ag_new = np.empty((self.rollout_batch_size, self.dims['g']))
            success = np.zeros(self.rollout_batch_size)

            # compute new states and observations
            for i in range(self.rollout_batch_size):
                try:
                    # We fully ignore the reward here because it will have to be re-computed
                    # for HER.
                    curr_o_new, _, _, info = self.envs[i].step(u[i])
                    if 'is_success' in info:
                        success[i] = info['is_success']
                    o_new[i] = curr_o_new['observation']
                    ag_new[i] = curr_o_new['achieved_goal']
                    for idx, key in enumerate(self.info_keys):
                        info_values[idx][t, i] = info[key]
                    if self.render:
                        self.envs[i].render()
                except MujocoException:
                    # A simulation instability invalidates the whole batch; start over.
                    return self.generate_rollouts()

            if np.isnan(o_new).any():
                self.logger.warning('NaN caught during rollout generation. Trying again...')
                self.reset_all_rollouts()
                return self.generate_rollouts()

            obs.append(o.copy())
            achieved_goals.append(ag.copy())
            successes.append(success.copy())
            acts.append(u.copy())
            goals.append(self.g.copy())
            o[...] = o_new
            ag[...] = ag_new
        # `obs` and `achieved_goals` hold T+1 entries (including the final
        # state); `acts` and `goals` hold T.
        obs.append(o.copy())
        achieved_goals.append(ag.copy())
        self.initial_o[:] = o

        episode = dict(o=obs,
                       u=acts,
                       g=goals,
                       ag=achieved_goals)
        for key, value in zip(self.info_keys, info_values):
            episode['info_{}'.format(key)] = value

        # stats
        successful = np.array(successes)[-1, :]
        assert successful.shape == (self.rollout_batch_size,)
        success_rate = np.mean(successful)
        self.success_history.append(success_rate)
        if self.compute_Q:
            self.Q_history.append(np.mean(Qs))
        self.n_episodes += self.rollout_batch_size

        return convert_episode_to_batch_major(episode)
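
    # Shape sketch (illustrative, under the batch-major assumption stated at
    # the top of the file): an episode generated with rollout_batch_size=2 and
    # T=50 is expected to contain
    #   episode['o']  -> (2, 51, dims['o'])   # T+1 observations
    #   episode['ag'] -> (2, 51, dims['g'])
    #   episode['u']  -> (2, 50, dims['u'])
    #   episode['g']  -> (2, 50, dims['g'])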

    def clear_history(self):
        """Clears all histories that are used for statistics.
        """
        self.success_history.clear()
        self.Q_history.clear()

    def current_success_rate(self):
        return np.mean(self.success_history)

    def current_mean_Q(self):
        return np.mean(self.Q_history)

    def save_policy(self, path):
        """Pickles the current policy for later inspection.
        """
        with open(path, 'wb') as f:
            pickle.dump(self.policy, f)

    def logs(self, prefix='worker'):
        """Generates a list of (key, value) pairs with all collected statistics.
        """
        logs = []
        logs += [('success_rate', np.mean(self.success_history))]
        if self.compute_Q:
            logs += [('mean_Q', np.mean(self.Q_history))]
        logs += [('episode', self.n_episodes)]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs
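
    # Example (illustrative): `logs('test')` yields entries such as
    # [('test/success_rate', 0.9), ('test/episode', 200)], plus
    # ('test/mean_Q', ...) when `compute_Q` is enabled.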

    def seed(self, seed):
        """Seeds each environment with a distinct seed derived from the passed in global seed.
        """
        for idx, env in enumerate(self.envs):
            env.seed(seed + 1000 * idx)
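

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # the classic (pre-0.26) gym API with the goal-based 'FetchReach-v1'
    # environment installed, and substitutes a uniform-random stand-in for
    # the learned policy; `RandomPolicy` is purely illustrative.
    import logging
    import gym

    class RandomPolicy:
        """Stand-in policy exposing the `get_actions` interface used above."""

        def __init__(self, dim_u):
            self.dim_u = dim_u

        def get_actions(self, o, ag, g, compute_Q=False, noise_eps=0.,
                        random_eps=0., use_target_net=False):
            # Sample one uniform action per parallel rollout.
            u = np.random.uniform(-1., 1., (o.shape[0], self.dim_u))
            return (u, np.zeros((o.shape[0], 1))) if compute_Q else u

    def make_env():
        return gym.make('FetchReach-v1')

    env = make_env()
    dims = dict(o=env.observation_space.spaces['observation'].shape[0],
                g=env.observation_space.spaces['desired_goal'].shape[0],
                u=env.action_space.shape[0])
    worker = RolloutWorker(make_env, RandomPolicy(dims['u']), dims,
                           logging.getLogger('worker'), T=50,
                           rollout_batch_size=2, noise_eps=0.2, random_eps=0.3)
    worker.seed(42)
    episode = worker.generate_rollouts()
    print({key: value.shape for key, value in episode.items()})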