from replay_buffer import replay_buffer
from net import disc_policy_net, value_net, discriminator, cont_policy_net
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
import gym
import random


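# GAIL agent: PPO policy optimization driven by a discriminator-derived reward.
# The local modules are assumed (from how they are used below) to provide:
#   replay_buffer(capacity, gamma, lam): store / sample / process / clear / __len__, with GAE returns and advantages
#   disc_policy_net / cont_policy_net:   forward, act, and (continuous case) get_distribution
#   value_net:                           forward -> state value
#   discriminator:                       forward -> sigmoid probability for a (state, action) pair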
class gail(object):
    def __init__(self, env, episode, capacity, gamma, lam, is_disc, value_learning_rate, policy_learning_rate, discriminator_learning_rate, batch_size, file, policy_iter, disc_iter, value_iter, epsilon, entropy_weight, train_iter, clip_grad, render):
        self.env = env
        self.episode = episode
        self.capacity = capacity
        self.gamma = gamma
        self.lam = lam
        self.is_disc = is_disc
        self.value_learning_rate = value_learning_rate
        self.policy_learning_rate = policy_learning_rate
        self.discriminator_learning_rate = discriminator_learning_rate
        self.batch_size = batch_size
        self.file = file
        self.policy_iter = policy_iter
        self.disc_iter = disc_iter
        self.value_iter = value_iter
        self.epsilon = epsilon
        self.entropy_weight = entropy_weight
        self.train_iter = train_iter
        self.clip_grad = clip_grad
        self.render = render

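        # is_disc selects discrete (Categorical) vs. continuous action spaces.
        # Networks, the GAE buffer, and optimizers are built below; the expert pool
        # is loaded from `file`, which is assumed to be an already-open binary file
        # handle containing pickled (observation, action) pairs.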
        self.observation_dim = self.env.observation_space.shape[0]
        if is_disc:
            self.action_dim = self.env.action_space.n
            self.policy_net = disc_policy_net(self.observation_dim, self.action_dim)
        else:
            self.action_dim = self.env.action_space.shape[0]
            self.policy_net = cont_policy_net(self.observation_dim, self.action_dim)
        self.value_net = value_net(self.observation_dim, 1)
        self.discriminator = discriminator(self.observation_dim + self.action_dim)
        self.buffer = replay_buffer(self.capacity, self.gamma, self.lam)
        self.pool = pickle.load(self.file)
        self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.policy_learning_rate)
        self.value_optimizer = torch.optim.Adam(self.value_net.parameters(), lr=self.value_learning_rate)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.discriminator_learning_rate)
        self.disc_loss_func = nn.BCELoss()
        self.weight_reward = None
        self.weight_custom_reward = None

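    # One PPO update on the data currently in the buffer: fit the value net by
    # MSE regression on the returns, then take several clipped-surrogate policy
    # steps with an entropy bonus, using the GAE advantages from the buffer.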
    def ppo_train(self):
        observations, actions, returns, advantages = self.buffer.sample(self.batch_size)
        observations = torch.FloatTensor(observations)
        advantages = torch.FloatTensor(advantages).unsqueeze(1)
        # normalize advantages (small epsilon guards against a zero std)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        advantages = advantages.detach()
        returns = torch.FloatTensor(returns).unsqueeze(1).detach()

        for _ in range(self.value_iter):
            values = self.value_net.forward(observations)
            value_loss = (returns - values).pow(2).mean()
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

        if self.is_disc:
            actions_d = torch.LongTensor(actions).unsqueeze(1)
            old_probs = self.policy_net.forward(observations)
            # entropy must come from the full action distribution; a Categorical
            # built from only the gathered per-action probabilities would always
            # have zero entropy
            dist = torch.distributions.Categorical(old_probs)
            entropy = dist.entropy().unsqueeze(1)
            old_probs = old_probs.gather(1, actions_d)
            for _ in range(self.policy_iter):
                probs = self.policy_net.forward(observations)
                probs = probs.gather(1, actions_d)
                ratio = probs / old_probs.detach()
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1. - self.epsilon, 1. + self.epsilon) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.entropy_weight * entropy
                policy_loss = policy_loss.mean()
                self.policy_optimizer.zero_grad()
                policy_loss.backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip_grad)
                self.policy_optimizer.step()
        else:
            actions_c = torch.FloatTensor(actions)
            old_dist = self.policy_net.get_distribution(observations)
            old_log_probs = old_dist.log_prob(actions_c)
            entropy = old_dist.entropy().unsqueeze(1)
            for _ in range(self.policy_iter):
                dist = self.policy_net.get_distribution(observations)
                log_probs = dist.log_prob(actions_c)
                ratio = torch.exp(log_probs - old_log_probs.detach())
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1. - self.epsilon, 1. + self.epsilon) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.entropy_weight * entropy
                policy_loss = policy_loss.mean()
                self.policy_optimizer.zero_grad()
                policy_loss.backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip_grad)
                self.policy_optimizer.step()

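    # Train the discriminator to separate expert (state, action) pairs from
    # pairs generated by the current policy, using a standard BCE loss.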
    def discriminator_train(self):
        # expert pairs are labeled 0, policy-generated pairs 1, matching the
        # -log D(s, a) reward used in get_reward below
        expert_batch = random.sample(self.pool, self.batch_size)
        expert_observations, expert_actions = zip(*expert_batch)
        expert_observations = np.vstack(expert_observations)
        expert_observations = torch.FloatTensor(expert_observations)
        if self.is_disc:
            # one-hot encode discrete expert actions
            expert_actions_index = torch.LongTensor(expert_actions).unsqueeze(1)
            expert_actions = torch.zeros(self.batch_size, self.action_dim)
            expert_actions.scatter_(1, expert_actions_index, 1)
        else:
            # continuous expert actions are assumed to be stored as scalars (one action dimension)
            expert_actions = torch.FloatTensor(expert_actions).unsqueeze(1)
        expert_trajs = torch.cat([expert_observations, expert_actions], 1)
        expert_labels = torch.FloatTensor(self.batch_size, 1).fill_(0.0)

        observations, actions, _, _ = self.buffer.sample(self.batch_size)
        observations = torch.FloatTensor(observations)
        if self.is_disc:
            actions_index = torch.LongTensor(actions).unsqueeze(1)
            actions_dis = torch.zeros(self.batch_size, self.action_dim)
            actions_dis.scatter_(1, actions_index, 1)
        else:
            actions_dis = torch.FloatTensor(actions)
        trajs = torch.cat([observations, actions_dis], 1)
        labels = torch.FloatTensor(self.batch_size, 1).fill_(1.0)

        for _ in range(self.disc_iter):
            expert_loss = self.disc_loss_func(self.discriminator.forward(expert_trajs), expert_labels)
            current_loss = self.disc_loss_func(self.discriminator.forward(trajs), labels)
            loss = (expert_loss + current_loss) / 2
            self.discriminator_optimizer.zero_grad()
            loss.backward()
            self.discriminator_optimizer.step()

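    # GAIL surrogate reward r(s, a) = -log D(s, a): with the labels above, the
    # discriminator outputs values near 0 for expert-like pairs, so the reward
    # is large when the policy's behavior resembles the expert's.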
    def get_reward(self, observation, action):
        observation = torch.FloatTensor(np.expand_dims(observation, 0))
        if self.is_disc:
            action_tensor = torch.zeros(1, self.action_dim)
            action_tensor[0, action] = 1.
        else:
            action_tensor = torch.FloatTensor(action).unsqueeze(1)
        traj = torch.cat([observation, action_tensor], 1)
        reward = self.discriminator.forward(traj)
        reward = - reward.log()
        return reward.detach().item()

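    # Main loop: roll out the current policy, store transitions with the
    # discriminator-based reward (the environment reward is only logged), and
    # run a discriminator + PPO update once the buffer holds train_iter steps.
    # Note: this follows the old Gym API (reset() returns obs, step() returns a 4-tuple).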
    def run(self):
        for i in range(self.episode):
            obs = self.env.reset()
            if self.render:
                self.env.render()
            total_reward = 0
            total_custom_reward = 0
            while True:
                action = self.policy_net.act(torch.FloatTensor(np.expand_dims(obs, 0)))
                if not self.is_disc:
                    action = [action]
                next_obs, reward, done, _ = self.env.step(action)
                custom_reward = self.get_reward(obs, action)
                value = self.value_net.forward(torch.FloatTensor(np.expand_dims(obs, 0))).detach().item()
                self.buffer.store(obs, action, custom_reward, done, value)
                total_reward += reward
                total_custom_reward += custom_reward
                obs = next_obs
                if self.render:
                    self.env.render()

                if done:
                    # exponential moving averages of the environment and discriminator rewards
                    if self.weight_reward is None:
                        self.weight_reward = total_reward
                    else:
                        self.weight_reward = 0.99 * self.weight_reward + 0.01 * total_reward
                    if self.weight_custom_reward is None:
                        self.weight_custom_reward = total_custom_reward
                    else:
                        self.weight_custom_reward = 0.99 * self.weight_custom_reward + 0.01 * total_custom_reward
                    if len(self.buffer) >= self.train_iter:
                        self.buffer.process()
                        self.discriminator_train()
                        self.ppo_train()
                        self.buffer.clear()
                    print('episode: {} reward: {:.2f} custom_reward: {:.3f} weight_reward: {:.2f} weight_custom_reward: {:.4f}'.format(i + 1, total_reward, total_custom_reward, self.weight_reward, self.weight_custom_reward))
                    break
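

# Example entry point: a minimal sketch with assumed hyperparameters;
# 'expert_data.pkl' is a hypothetical pickle of expert (observation, action)
# pairs collected beforehand from an expert policy.
if __name__ == '__main__':
    env = gym.make('CartPole-v1')
    with open('expert_data.pkl', 'rb') as f:
        agent = gail(env=env, episode=1000, capacity=10000, gamma=0.99, lam=0.95,
                     is_disc=True, value_learning_rate=3e-3, policy_learning_rate=3e-4,
                     discriminator_learning_rate=3e-4, batch_size=64, file=f,
                     policy_iter=10, disc_iter=5, value_iter=10, epsilon=0.2,
                     entropy_weight=1e-3, train_iter=64, clip_grad=40, render=False)
    agent.run()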