Commit ce4d343: update ppo

Parent: 796beaf

9 files changed: +1569, -263 lines

ppo_actor (68.9 KB)
Binary file not shown.

ppo_continuous.py (+44, -26)
```diff
@@ -1,8 +1,13 @@
-'''
-PPO
-'''
-
-
+"""
+Proximal Policy Optimization (PPO) version 2
+----------------------------
+2 actors and 1 critic
+old policy given by old actor, which is delayed copy of actor
+
+To run
+------
+python tutorial_PPO.py --train/test
+"""
 import math
 import random
 
```
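The new docstring says the script is driven by `--train`/`--test` flags. The actual argument parsing lives elsewhere in the file and is not part of this hunk; as a hedged sketch, such switches are typically wired up with `argparse` along these lines:

```python
import argparse

# Sketch only: the real parser is defined elsewhere in ppo_continuous.py and is
# not shown in this diff; this is the usual shape of the --train/--test flags.
parser = argparse.ArgumentParser(description='Train or test a PPO agent on a continuous-control env.')
parser.add_argument('--train', dest='train', action='store_true', default=False)
parser.add_argument('--test', dest='test', action='store_true', default=False)
args = parser.parse_args()

# usage: python ppo_continuous.py --train
#        python ppo_continuous.py --test
```
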

```diff
@@ -50,7 +55,7 @@
 
 ##################### hyper parameters ####################
 
-ENV_NAME = 'Pendulum-v0' # environment name
+ENV_NAME = 'HalfCheetah-v2' # environment name HalfCheetah-v2 Pendulum-v0
 RANDOMSEED = 2 # random seed
 
 EP_MAX = 1000 # total number of episodes for training
```

```diff
@@ -78,8 +83,8 @@ def __init__(self, state_dim, hidden_dim, init_w=3e-3):
         # self.linear3 = nn.Linear(hidden_dim, hidden_dim)
         self.linear4 = nn.Linear(hidden_dim, 1)
         # weights initialization
-        self.linear4.weight.data.uniform_(-init_w, init_w)
-        self.linear4.bias.data.uniform_(-init_w, init_w)
+        # self.linear4.weight.data.uniform_(-init_w, init_w)
+        # self.linear4.bias.data.uniform_(-init_w, init_w)
 
     def forward(self, state):
         x = F.relu(self.linear1(state))
```

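With the `uniform_(-init_w, init_w)` lines commented out, the critic's output layer keeps PyTorch's default `nn.Linear` initialization instead of a small hand-tuned range. A minimal sketch of the difference, with illustrative sizes:

```python
import torch.nn as nn

# Default nn.Linear init (Kaiming-uniform weights, uniform bias) vs. the old
# small-uniform init that this commit comments out. Sizes are illustrative.
layer = nn.Linear(512, 1)
print(float(layer.weight.abs().max()))   # roughly on the order of 1/sqrt(512)

init_w = 3e-3
layer.weight.data.uniform_(-init_w, init_w)
layer.bias.data.uniform_(-init_w, init_w)
print(float(layer.weight.abs().max()))   # now bounded by 3e-3
```
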
```diff
@@ -101,24 +106,23 @@ def __init__(self, num_inputs, num_actions, hidden_dim, action_range=1., init_w=
         # self.linear4 = nn.Linear(hidden_dim, hidden_dim)
 
         self.mean_linear = nn.Linear(hidden_dim, num_actions)
-        self.mean_linear.weight.data.uniform_(-init_w, init_w)
-        self.mean_linear.bias.data.uniform_(-init_w, init_w)
+        # self.mean_linear.weight.data.uniform_(-init_w, init_w)
+        # self.mean_linear.bias.data.uniform_(-init_w, init_w)
 
         self.log_std_linear = nn.Linear(hidden_dim, num_actions)
-        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
-        self.log_std_linear.bias.data.uniform_(-init_w, init_w)
+        # self.log_std_linear.weight.data.uniform_(-init_w, init_w)
+        # self.log_std_linear.bias.data.uniform_(-init_w, init_w)
 
         self.num_actions = num_actions
+        self.action_range = action_range
 
-
     def forward(self, state):
         x = F.relu(self.linear1(state))
         x = F.relu(self.linear2(x))
         # x = F.relu(self.linear3(x))
         # x = F.relu(self.linear4(x))
 
-        mean = self.mean_linear(x)
-        # mean = F.leaky_relu(self.mean_linear(x))
+        mean = self.action_range * F.tanh(self.mean_linear(x))
         log_std = self.log_std_linear(x)
         log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
 
```

```diff
@@ -130,7 +134,11 @@ def get_action(self, state, deterministic=False):
         std = log_std.exp()
         normal = Normal(0, 1)
         z = normal.sample()
-        action = mean+std*z
+        if deterministic:
+            action = mean
+        else:
+            action = mean+std*z
+        action = torch.clamp(action, -self.action_range, self.action_range)
         return action.squeeze(0)
 
     def sample_action(self,):
```

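Taken together, these two hunks bound the policy: the mean is squashed with tanh and scaled by `action_range`, sampling can be switched to the deterministic mean at test time, and the final action is clamped into `[-action_range, action_range]`. A minimal, self-contained sketch of that logic (dimensions and values are illustrative; newer PyTorch prefers `torch.tanh` over the deprecated `F.tanh` used in the diff):

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

# Illustrative sketch of the new policy head and sampling rule.
hidden_dim, num_actions, action_range = 512, 6, 2.0
log_std_min, log_std_max = -20, 2

mean_linear = nn.Linear(hidden_dim, num_actions)
log_std_linear = nn.Linear(hidden_dim, num_actions)

x = torch.randn(1, hidden_dim)                        # stand-in for the hidden features
mean = action_range * torch.tanh(mean_linear(x))      # bounded mean
log_std = torch.clamp(log_std_linear(x), log_std_min, log_std_max)
std = log_std.exp()

def get_action(mean, std, deterministic=False):
    z = Normal(0, 1).sample()                         # standard normal noise, broadcast over mean
    action = mean if deterministic else mean + std * z
    return torch.clamp(action, -action_range, action_range).squeeze(0)

print(get_action(mean, std))                          # noisy training action
print(get_action(mean, std, deterministic=True))      # mean action, as used at test time
```
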
```diff
@@ -161,8 +169,8 @@ class PPO(object):
     PPO class
     '''
     def __init__(self, state_dim, action_dim, hidden_dim=512, a_lr=3e-4, c_lr=3e-4):
-        self.actor = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
-        self.actor_old = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
+        self.actor = PolicyNetwork(state_dim, action_dim, hidden_dim, 2.).to(device)
+        self.actor_old = PolicyNetwork(state_dim, action_dim, hidden_dim, 2.).to(device)
         self.critic = ValueNetwork(state_dim, hidden_dim).to(device)
         self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=a_lr)
         self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=c_lr)
```

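The docstring above describes `actor_old` as a delayed copy of `actor` that supplies the old policy in the PPO ratio. A hedged sketch of how such a copy is typically refreshed before each round of policy updates (the helper name is illustrative, not taken from this diff):

```python
import torch.nn as nn

def sync_old_policy(actor: nn.Module, actor_old: nn.Module) -> None:
    """Hard-copy the current actor's parameters into the old actor."""
    actor_old.load_state_dict(actor.state_dict())

# usage sketch: call once per update round, then compute the clipped surrogate
# with ratio = pi_new(a|s) / pi_old(a|s), where pi_old comes from the frozen
# actor_old while actor is being optimized.
```
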
```diff
@@ -270,13 +278,13 @@ def update(self, s, a, r):
         for _ in range(C_UPDATE_STEPS):
             self.c_train(r, s)
 
-    def choose_action(self, s):
+    def choose_action(self, s, deterministic=False):
         '''
         Choose action
         :param s: state
         :return: clipped act
         '''
-        a = self.actor.get_action(s)
+        a = self.actor.get_action(s, deterministic)
         return a.detach().cpu().numpy()
 
     def get_v(self, s):
```

```diff
@@ -288,7 +296,9 @@ def get_v(self, s):
         s = s.astype(np.float32)
         if s.ndim < 2: s = s[np.newaxis, :]
         s = torch.FloatTensor(s).to(device)
-        return self.critic(s).detach().cpu().numpy()[0, 0]
+        # return self.critic(s).detach().cpu().numpy()[0, 0]
+        return self.critic(s).squeeze(0).detach().cpu().numpy()
+
 
     def save_model(self, path):
         torch.save(self.actor.state_dict(), path+'_actor')
```

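`get_v` now returns a length-1 NumPy array (via `squeeze(0)`) instead of a bare float (via `[0, 0]`), which is why the training loop further down indexes `ppo.get_v(s_)[0]`. A tiny shape sketch with a stand-in critic:

```python
import torch
import torch.nn as nn

# Stand-in critic to show the shapes involved; the real ValueNetwork maps
# state_dim -> 1 through several hidden layers.
critic = nn.Linear(3, 1)
s = torch.zeros(1, 3)                       # a single state, batched
v = critic(s)                               # shape (1, 1)

print(v.detach().cpu().numpy()[0, 0])       # old behaviour: a plain float
print(v.squeeze(0).detach().cpu().numpy())  # new behaviour: array of shape (1,)
```
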
```diff
@@ -307,7 +317,7 @@ def load_model(self, path):
 
 def main():
 
-    env = gym.make(ENV_NAME).unwrapped
+    env = NormalizedActions(gym.make(ENV_NAME).unwrapped)
     state_dim = env.observation_space.shape[0]
     action_dim = env.action_space.shape[0]
 
```

```diff
@@ -341,8 +351,11 @@ def main():
                 ep_r += r
 
                 # update ppo
-                if (t + 1) % BATCH == 0 or t == EP_LEN - 1:
-                    v_s_ = ppo.get_v(s_)
+                if (t + 1) % BATCH == 0 or t == EP_LEN - 1 or done:
+                    if done:
+                        v_s_=0
+                    else:
+                        v_s_ = ppo.get_v(s_)[0]
                     discounted_r = []
                     for r in buffer['reward'][::-1]:
                         v_s_ = r + GAMMA * v_s_
```

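The update now also fires when an episode terminates, and the discounted returns are bootstrapped with 0 at a terminal state instead of the critic's estimate of `s_`. A standalone sketch of that return computation (the gamma and rewards below are illustrative):

```python
import numpy as np

def discounted_returns(rewards, v_bootstrap, gamma=0.9):
    """Build discounted returns backwards from a bootstrap value."""
    v = v_bootstrap
    out = []
    for r in reversed(rewards):
        v = r + gamma * v
        out.append(v)
    out.reverse()
    return np.array(out)

print(discounted_returns([1.0, 1.0, 1.0], v_bootstrap=0.0))  # terminal episode: [2.71 1.9  1. ]
print(discounted_returns([1.0, 1.0, 1.0], v_bootstrap=5.0))  # mid-episode: bootstrap from V(s_)
```
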
```diff
@@ -352,6 +365,9 @@ def main():
                     bs, ba, br = np.vstack(buffer['state']), np.vstack(buffer['action']), np.array(discounted_r)[:, np.newaxis]
                     buffer['state'], buffer['action'], buffer['reward'] = [], [], []
                     ppo.update(bs, ba, br)
+
+                    if done:
+                        break
             if ep == 0:
                 all_ep_r.append(ep_r)
             else:
```

```diff
@@ -369,7 +385,7 @@ def main():
             plt.cla()
             plt.title('PPO')
             plt.plot(np.arange(len(all_ep_r)), all_ep_r)
-            plt.ylim(-2000, 0)
+            # plt.ylim(-2000, 0)
             plt.xlabel('Episode')
             plt.ylabel('Moving averaged episode reward')
             plt.show()
```

```diff
@@ -384,7 +400,9 @@ def main():
             s = env.reset()
             for i in range(EP_LEN):
                 env.render()
-                s, r, done, _ = env.step(ppo.choose_action(s))
+                a = ppo.choose_action(s, True)
+                print(a)
+                s, r, done, _ = env.step(a)
                 if done:
                     break
 if __name__ == '__main__':
```
