train_ppo.py
import gym
import numpy as np
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

from buffer import OnlineReplayBuffer
from configs.ppo.config import get_config
from models.ppo import PPO
from trainer.ppo import PPOTrainer
from utils import netset_randomness

# Seed all random number generators for reproducibility.
netset_randomness(42)
# Load hyperparameters and build the environment.
args = get_config()
env = gym.make(args.env_name)

# Assumes continuous (Box) observation and action spaces; a Discrete action
# space has an empty shape and would need act_dim = env.action_space.n instead.
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
print(f"Observation space: {obs_dim}, Action space: {act_dim}")
# On-policy buffer; gamma and lam configure discounting and GAE-lambda
# advantage estimation (see the sketch below).
buffer = OnlineReplayBuffer(obs_dim, act_dim, args.steps_per_epoch, args.gamma, args.lam)
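
# A minimal sketch of what GAE-lambda advantage estimation looks like, since
# the buffer takes gamma and lam. This is an assumption about the buffer's
# internals; `gae_advantages` is a hypothetical name, not its actual API.
def gae_advantages(rewards, values, last_value, gamma, lam):
    """Backward-recursive GAE: A_t = delta_t + gamma * lam * A_{t+1}."""
    values_ext = np.append(values, last_value)  # bootstrap with V of the final state
    deltas = rewards + gamma * values_ext[1:] - values_ext[:-1]
    advantages = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages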

# Actor-critic model with separate optimizers for the policy (pi) and the
# value function (v).
ppo = PPO(obs_dim=obs_dim, act_dim=act_dim, args=args)
pi_optimizer = Adam(ppo.pi.parameters(), lr=args.pi_lr)
vf_optimizer = Adam(ppo.v.parameters(), lr=args.vf_lr)

# TensorBoard logging, then run the training loop.
writer = SummaryWriter()
trainer = PPOTrainer(lambda: env, ppo, args, buffer, writer, pi_optimizer, vf_optimizer)
trainer.train()
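
# PPOTrainer presumably optimizes PPO's clipped surrogate objective for pi.
# A minimal sketch of that loss follows; `clipped_policy_loss` and the 0.2
# clip ratio are illustrative assumptions, not the trainer's actual API.
import torch

def clipped_policy_loss(logp_new, logp_old, advantages, clip_ratio=0.2):
    """PPO-clip: -E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)]."""
    ratio = torch.exp(logp_new - logp_old)
    clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
    return -torch.min(ratio * advantages, clipped).mean()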