Commit c3f5807: 🎉 first commit (0 parents)

14 files changed, +847 -0 lines changed

.vscode/settings.json

+3
@@ -0,0 +1,3 @@
{
    "python.pythonPath": "/home/xzw/anaconda3/bin/python"
}

Readme.md

+7
@@ -0,0 +1,7 @@
# GAIL

This project is implemented on two classical control problems, *Cartpole* and *Pendulum*, which represent the discrete and continuous cases respectively.

* First, collect expert trajectories with the PPO algorithm.
* Then use GAIL to imitate these expert trajectories (a sketch of the surrogate reward it optimizes follows this list).
* The paper uses TRPO to optimize the policy net; here **PPO** with **GAE** is used instead.
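
For intuition, the surrogate reward that GAIL feeds to PPO is r(s, a) = -log D(s, a), matching `gail.get_reward` in gail.py below, where the discriminator D is trained toward 1 on policy samples and 0 on expert samples. A minimal sketch, assuming a discriminator with the same call signature as net.py's `discriminator`; the clamp is an added numerical safeguard that is not in the commit itself:

```python
# Minimal sketch (not part of this commit) of the GAIL surrogate reward.
import torch

def gail_reward(discriminator, observation, action_one_hot):
    """observation: [1, obs_dim] tensor; action_one_hot: [1, act_dim] tensor."""
    traj = torch.cat([observation, action_one_hot], dim=1)
    d = discriminator(traj)                   # output in (0, 1); ~0 means "looks like expert"
    return (-d.clamp(min=1e-8).log()).item()  # large reward for expert-like pairs
```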

__pycache__/gail.cpython-37.pyc

5.3 KB
Binary file not shown.

__pycache__/net.cpython-37.pyc

3.36 KB
Binary file not shown.

1.72 KB
Binary file not shown.

cartpole_test.py

+30
@@ -0,0 +1,30 @@
import gym
from gail import gail


if __name__ == '__main__':
    # * the performance improves noticeably in this discrete case
    env = gym.make('CartPole-v0')
    file = open('./traj/cartpole.pkl', 'rb')
    test = gail(
        env=env,
        episode=10000000,
        capacity=1000,
        gamma=0.99,
        lam=0.95,
        is_disc=True,
        value_learning_rate=3e-4,
        policy_learning_rate=3e-4,
        discriminator_learning_rate=3e-4,
        batch_size=64,
        file=file,
        policy_iter=1,
        disc_iter=10,
        value_iter=1,
        epsilon=0.2,
        entropy_weight=1e-4,
        train_iter=500,
        clip_grad=40,
        render=False
    )
    test.run()

gail.py

+187
@@ -0,0 +1,187 @@
from replay_buffer import replay_buffer
from net import disc_policy_net, value_net, discriminator, cont_policy_net
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
import gym
import random


class gail(object):
    def __init__(self, env, episode, capacity, gamma, lam, is_disc, value_learning_rate, policy_learning_rate, discriminator_learning_rate, batch_size, file, policy_iter, disc_iter, value_iter, epsilon, entropy_weight, train_iter, clip_grad, render):
        self.env = env
        self.episode = episode
        self.capacity = capacity
        self.gamma = gamma
        self.lam = lam
        self.is_disc = is_disc
        self.value_learning_rate = value_learning_rate
        self.policy_learning_rate = policy_learning_rate
        self.discriminator_learning_rate = discriminator_learning_rate
        self.batch_size = batch_size
        self.file = file
        self.policy_iter = policy_iter
        self.disc_iter = disc_iter
        self.value_iter = value_iter
        self.epsilon = epsilon
        self.entropy_weight = entropy_weight
        self.train_iter = train_iter
        self.clip_grad = clip_grad
        self.render = render

        self.observation_dim = self.env.observation_space.shape[0]
        if is_disc:
            self.action_dim = self.env.action_space.n
        else:
            self.action_dim = self.env.action_space.shape[0]
        if is_disc:
            self.policy_net = disc_policy_net(self.observation_dim, self.action_dim)
        else:
            self.policy_net = cont_policy_net(self.observation_dim, self.action_dim)
        self.value_net = value_net(self.observation_dim, 1)
        self.discriminator = discriminator(self.observation_dim + self.action_dim)
        self.buffer = replay_buffer(self.capacity, self.gamma, self.lam)
        self.pool = pickle.load(self.file)  # expert (observation, action) pairs
        self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.policy_learning_rate)
        self.value_optimizer = torch.optim.Adam(self.value_net.parameters(), lr=self.value_learning_rate)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.discriminator_learning_rate)
        self.disc_loss_func = nn.BCELoss()
        self.weight_reward = None
        self.weight_custom_reward = None

    def ppo_train(self):
        # clipped PPO update on the discriminator-based rewards stored in the buffer
        observations, actions, returns, advantages = self.buffer.sample(self.batch_size)
        observations = torch.FloatTensor(observations)
        advantages = torch.FloatTensor(advantages).unsqueeze(1)
        advantages = (advantages - advantages.mean()) / advantages.std()
        advantages = advantages.detach()
        returns = torch.FloatTensor(returns).unsqueeze(1).detach()

        for _ in range(self.value_iter):
            values = self.value_net.forward(observations)
            value_loss = (returns - values).pow(2).mean()
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

        if self.is_disc:
            actions_d = torch.LongTensor(actions).unsqueeze(1)
            old_probs = self.policy_net.forward(observations)
            old_probs = old_probs.gather(1, actions_d)
            dist = torch.distributions.Categorical(old_probs)
            entropy = dist.entropy().unsqueeze(1)
            for _ in range(self.policy_iter):
                probs = self.policy_net.forward(observations)
                probs = probs.gather(1, actions_d)
                ratio = probs / old_probs.detach()
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1. - self.epsilon, 1. + self.epsilon) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.entropy_weight * entropy
                policy_loss = policy_loss.mean()
                self.policy_optimizer.zero_grad()
                policy_loss.backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip_grad)
                self.policy_optimizer.step()
        else:
            actions_c = torch.FloatTensor(actions)
            old_dist = self.policy_net.get_distribution(observations)
            old_log_probs = old_dist.log_prob(actions_c)
            entropy = old_dist.entropy().unsqueeze(1)
            for _ in range(self.policy_iter):
                dist = self.policy_net.get_distribution(observations)
                log_probs = dist.log_prob(actions_c)
                ratio = torch.exp(log_probs - old_log_probs.detach())
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1. - self.epsilon, 1. + self.epsilon) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.entropy_weight * entropy
                policy_loss = policy_loss.mean()
                self.policy_optimizer.zero_grad()
                policy_loss.backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip_grad)
                self.policy_optimizer.step()

    def discriminator_train(self):
        # expert pairs are labeled 0, on-policy pairs are labeled 1
        expert_batch = random.sample(self.pool, self.batch_size)
        expert_observations, expert_actions = zip(*expert_batch)
        expert_observations = np.vstack(expert_observations)
        expert_observations = torch.FloatTensor(expert_observations)
        if self.is_disc:
            expert_actions_index = torch.LongTensor(expert_actions).unsqueeze(1)
            expert_actions = torch.zeros(self.batch_size, self.action_dim)
            expert_actions.scatter_(1, expert_actions_index, 1)
        else:
            expert_actions = torch.FloatTensor(expert_actions).unsqueeze(1)
        expert_trajs = torch.cat([expert_observations, expert_actions], 1)
        expert_labels = torch.FloatTensor(self.batch_size, 1).fill_(0.0)

        observations, actions, _, _ = self.buffer.sample(self.batch_size)
        observations = torch.FloatTensor(observations)
        if self.is_disc:
            actions_index = torch.LongTensor(actions).unsqueeze(1)
            actions_dis = torch.zeros(self.batch_size, self.action_dim)
            actions_dis.scatter_(1, actions_index, 1)
        else:
            actions_dis = torch.FloatTensor(actions)
        trajs = torch.cat([observations, actions_dis], 1)
        labels = torch.FloatTensor(self.batch_size, 1).fill_(1.0)

        for _ in range(self.disc_iter):
            expert_loss = self.disc_loss_func(self.discriminator.forward(expert_trajs), expert_labels)
            current_loss = self.disc_loss_func(self.discriminator.forward(trajs), labels)

            loss = (expert_loss + current_loss) / 2
            self.discriminator_optimizer.zero_grad()
            loss.backward()
            self.discriminator_optimizer.step()

    def get_reward(self, observation, action):
        # surrogate reward r(s, a) = -log D(s, a): large when D thinks the pair looks expert-like
        observation = torch.FloatTensor(np.expand_dims(observation, 0))
        if self.is_disc:
            action_tensor = torch.zeros(1, self.action_dim)
            action_tensor[0, action] = 1.
        else:
            action_tensor = torch.FloatTensor(action).unsqueeze(1)
        traj = torch.cat([observation, action_tensor], 1)
        reward = self.discriminator.forward(traj)
        reward = - reward.log()
        return reward.detach().item()

    def run(self):
        for i in range(self.episode):
            obs = self.env.reset()
            if self.render:
                self.env.render()
            total_reward = 0
            total_custom_reward = 0
            while True:
                action = self.policy_net.act(torch.FloatTensor(np.expand_dims(obs, 0)))
                if not self.is_disc:
                    action = [action]
                next_obs, reward, done, _ = self.env.step(action)
                custom_reward = self.get_reward(obs, action)
                value = self.value_net.forward(torch.FloatTensor(np.expand_dims(obs, 0))).detach().item()
                self.buffer.store(obs, action, custom_reward, done, value)
                total_reward += reward
                total_custom_reward += custom_reward
                obs = next_obs
                if self.render:
                    self.env.render()

                if done:
                    if not self.weight_reward:
                        self.weight_reward = total_reward
                    else:
                        self.weight_reward = 0.99 * self.weight_reward + 0.01 * total_reward
                    if not self.weight_custom_reward:
                        self.weight_custom_reward = total_custom_reward
                    else:
                        self.weight_custom_reward = 0.99 * self.weight_custom_reward + 0.01 * total_custom_reward
                    if len(self.buffer) >= self.train_iter:
                        self.buffer.process()
                        self.discriminator_train()
                        self.ppo_train()
                        self.buffer.clear()
                    print('episode: {} reward: {:.2f} custom_reward: {:.3f} weight_reward: {:.2f} weight_custom_reward: {:.4f}'.format(i + 1, total_reward, total_custom_reward, self.weight_reward, self.weight_custom_reward))
                    break
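
gail.py stores (obs, action, custom_reward, done, value) transitions and calls replay_buffer.process() to turn them into returns and GAE advantages before sampling. replay_buffer.py is part of this commit but its diff is not shown here; a minimal sketch of the standard GAE(λ) recursion it presumably implements, with hypothetical argument names and the same gamma/lam hyperparameters, is:

```python
# Hypothetical sketch of GAE(lambda); replay_buffer.process() in this commit
# presumably computes something equivalent, but its source is not shown here.
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95, last_value=0.0):
    advantages, returns = [], []
    gae = 0.0
    next_value = last_value
    for r, v, d in zip(reversed(rewards), reversed(values), reversed(dones)):
        mask = 0.0 if d else 1.0
        delta = r + gamma * next_value * mask - v   # TD residual
        gae = delta + gamma * lam * mask * gae      # discounted sum of residuals
        advantages.insert(0, gae)
        returns.insert(0, gae + v)                  # advantage plus value baseline
        next_value = v
    return returns, advantages
```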

net.py

+92
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.distributions import Categorical, Normal


class disc_policy_net(nn.Module):
    # categorical (softmax) policy for discrete action spaces
    def __init__(self, input_dim, output_dim):
        super(disc_policy_net, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.fc1 = nn.Linear(self.input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, self.output_dim)

    def forward(self, input):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, 1)

    def act(self, input):
        probs = self.forward(input)
        dist = Categorical(probs)
        action = dist.sample()
        action = action.detach().item()
        return action


class cont_policy_net(nn.Module):
    # Gaussian policy with fixed unit standard deviation for continuous action spaces
    def __init__(self, input_dim, output_dim):
        super(cont_policy_net, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.fc1 = nn.Linear(self.input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, self.output_dim)

    def forward(self, input):
        x = torch.tanh(self.fc1(input))
        x = torch.tanh(self.fc2(x))
        mu = self.fc3(x)
        return mu

    def act(self, input):
        mu = self.forward(input)
        sigma = torch.ones_like(mu)
        dist = Normal(mu, sigma)
        action = dist.sample().detach().item()
        return action

    def get_distribution(self, input):
        mu = self.forward(input)
        sigma = torch.ones_like(mu)
        dist = Normal(mu, sigma)
        return dist


class value_net(nn.Module):
    # state-value baseline V(s)
    def __init__(self, input_dim, output_dim):
        super(value_net, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.fc1 = nn.Linear(self.input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, self.output_dim)

    def forward(self, input):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class discriminator(nn.Module):
    # D(s, a) in (0, 1); trained in gail.py with expert label 0 and policy label 1
    def __init__(self, input_dim):
        super(discriminator, self).__init__()
        self.input_dim = input_dim

        self.model = nn.Sequential(
            nn.Linear(self.input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.model(input)
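
For reference, a quick shape check (not part of the commit) showing how gail.__init__ wires these modules together in the discrete CartPole case (obs_dim=4, action_dim=2):

```python
# Illustrative smoke test: same wiring as gail.__init__ for CartPole.
import torch
from net import disc_policy_net, value_net, discriminator

obs_dim, act_dim = 4, 2
policy = disc_policy_net(obs_dim, act_dim)
critic = value_net(obs_dim, 1)
disc = discriminator(obs_dim + act_dim)

obs = torch.randn(1, obs_dim)
action = policy.act(obs)                          # int in {0, 1}
one_hot = torch.zeros(1, act_dim)
one_hot[0, action] = 1.
print(critic(obs).shape)                          # torch.Size([1, 1])
print(disc(torch.cat([obs, one_hot], 1)).shape)   # torch.Size([1, 1])
```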

pendulum_test.py

+32
@@ -0,0 +1,32 @@
import gym
from gail import gail


if __name__ == '__main__':
    # * GAIL doesn't perform well in the continuous case
    # * (maybe only in this case, under these hyperparameters)
    # * there is an oscillation phenomenon and training can't converge
    env = gym.make('Pendulum-v0')
    file = open('./traj/pendulum.pkl', 'rb')
    test = gail(
        env=env,
        episode=10000000,
        capacity=1000,
        gamma=0.99,
        lam=0.95,
        is_disc=False,
        value_learning_rate=1e-4,
        policy_learning_rate=1e-4,
        discriminator_learning_rate=3e-4,
        batch_size=64,
        file=file,
        policy_iter=3,
        disc_iter=10,
        value_iter=3,
        epsilon=0.05,
        entropy_weight=0,
        train_iter=600,
        clip_grad=0.2,
        render=False
    )
    test.run()
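
Both test scripts load ./traj/*.pkl, and gail.discriminator_train unpacks the pool with zip(*expert_batch), so each pickle is expected to hold a flat list of (observation, action) pairs. The trajectory files themselves are not shown in this diff; a hedged sketch of how such a file could be produced (collect_expert_pairs and expert_act are hypothetical names, standing in for the PPO expert mentioned in the Readme), written for the discrete CartPole case:

```python
# Hypothetical sketch of collecting expert trajectories into the pickle
# format that gail.discriminator_train expects: a list of (obs, action) pairs.
import pickle
import gym

def collect_expert_pairs(env_name, expert_act, n_pairs, out_path):
    env = gym.make(env_name)
    pool = []
    obs = env.reset()
    while len(pool) < n_pairs:
        action = expert_act(obs)          # expert policy, e.g. a previously trained PPO agent
        pool.append((obs, action))
        obs, _, done, _ = env.step(action)
        if done:
            obs = env.reset()
    with open(out_path, 'wb') as f:
        pickle.dump(pool, f)
```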
