rl-ppo.rs (forked from coreylowman/dfdx)
//! Implements the reinforcement learning algorithm Proximal Policy Optimization (PPO) on random data.
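//!
//! The loss minimized below is the negated PPO clipped surrogate objective:
//! loss = -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A)), where
//! r = pi(a|s) / pi_old(a|s) is the probability ratio and A is the advantage estimate.
//! This example uses eps = 0.2, which is where the clamp to [0.8, 1.2] comes from.
//! Assuming the file sits under `examples/` in the dfdx repo, it can be run with
//! `cargo run --example rl-ppo`.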
use dfdx::prelude::*;
use rand::{rngs::StdRng, Rng, SeedableRng};
use std::time::Instant;
const STATE_SIZE: usize = 4;
const ACTION_SIZE: usize = 2;
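
// Policy network: a two-hidden-layer MLP (32 units each, ReLU) mapping a
// STATE_SIZE-dimensional observation to one logit per action.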
type PolicyNetwork = (
    (Linear<STATE_SIZE, 32>, ReLU),
    (Linear<32, 32>, ReLU),
    Linear<32, ACTION_SIZE>,
);
fn main() {
    let mut rng = StdRng::seed_from_u64(0);

    // a fake batch of 64 transitions: random states, actions, and advantage estimates
    let state: Tensor2D<64, STATE_SIZE> = Tensor2D::randn(&mut rng);
    let action: [usize; 64] = [(); 64].map(|_| rng.gen_range(0..ACTION_SIZE));
    let advantage: Tensor1D<64> = Tensor1D::randn(&mut rng);

    // initialize the model: `Default::default()` gives all-zero weights, so randomize them
    let mut pi_net: PolicyNetwork = Default::default();
    pi_net.reset_params(&mut rng);
    let target_pi_net: PolicyNetwork = pi_net.clone();
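    // `target_pi_net` plays the role of the "old" policy that the probability ratio is
    // measured against; in a full PPO loop it would typically be refreshed with the
    // current weights after each round of rollouts and updates.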

    let mut sgd = Sgd::new(SgdConfig {
        lr: 1e-1,
        momentum: Some(Momentum::Nesterov(0.9)),
        weight_decay: None,
    });

    // run through training data
    for _i_epoch in 0..15 {
        let start = Instant::now();

        // old_log_prob_a = log(P(action | state, target_pi_net))
        let old_logits = target_pi_net.forward(state.clone());
        let old_log_prob_a: Tensor1D<64> = old_logits.log_softmax::<Axis<1>>().select(&action);

        // log_prob_a = log(P(action | state, pi_net))
        let logits = pi_net.forward(state.trace());
        let log_prob_a: Tensor1D<64, OwnedTape> = logits.log_softmax::<Axis<1>>().select(&action);

        // ratio = P(action | state, pi_net) / P(action | state, target_pi_net),
        // computed in log space and then brought back out of log space with .exp()
        let ratio = (log_prob_a - old_log_prob_a).exp();

        // because `ratio` is needed twice, some tape manipulation is required here.
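        // PPO's clipping: the ratio is clamped to [1 - eps, 1 + eps] with eps = 0.2, and the
        // objective takes the elementwise minimum of the unclipped (`surr1`) and clipped
        // (`surr2`) surrogates, which removes the incentive for overly large policy updates.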
        let surr1 = ratio.with_empty_tape() * advantage.clone();
        let surr2 = ratio.clamp(0.8, 1.2) * advantage.clone();

        let ppo_loss = -(minimum(surr2, surr1).mean());
        let loss_v = *ppo_loss.data();

        // run backprop
        let gradients = ppo_loss.backward();

        // update weights with optimizer
        sgd.update(&mut pi_net, gradients).expect("Unused params");

        println!("loss={:#} in {:?}", loss_v, start.elapsed());
    }
}