rl-dqn.rs (forked from coreylowman/dfdx)
//! Implements Deep Q Learning on random data.
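//!
//! Builds a small Q-network, generates one random batch of transitions, and
//! runs a few steps of the standard DQN update: regress Q(s, a) toward the
//! one-step TD target r + gamma * max_a' Q_target(s', a').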
use dfdx::prelude::*;
use rand::{rngs::StdRng, Rng, SeedableRng};
use std::time::Instant;
const STATE_SIZE: usize = 4;
const ACTION_SIZE: usize = 2;
// our simple 2 layer feedforward network with ReLU activations
type QNetwork = (
    (Linear<STATE_SIZE, 32>, ReLU),
    (Linear<32, 32>, ReLU),
    Linear<32, ACTION_SIZE>,
);
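// In dfdx a tuple of modules is itself a Module that applies its elements in
// order, so this type maps (batch, STATE_SIZE) -> 32 -> 32 -> (batch, ACTION_SIZE).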
fn main() {
    let mut rng = StdRng::seed_from_u64(0);

    let state: Tensor2D<64, STATE_SIZE> = Tensor2D::randn(&mut rng);
    let action: [usize; 64] = [(); 64].map(|_| rng.gen_range(0..ACTION_SIZE));
    let reward: Tensor1D<64> = Tensor1D::randn(&mut rng);
    let done: Tensor1D<64> = Tensor1D::zeros();
    let next_state: Tensor2D<64, STATE_SIZE> = Tensor2D::randn(&mut rng);
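    // In a real agent these five tensors would be sampled from a replay buffer
    // of (s, a, r, done, s') transitions; here they are random stand-ins.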
    // initialize model - Default gives all-zero weights, reset_params randomizes them
    let mut q_net: QNetwork = Default::default();
    q_net.reset_params(&mut rng);
    let target_q_net: QNetwork = q_net.clone();
    let mut sgd = Sgd::new(SgdConfig {
        lr: 1e-1,
        momentum: Some(Momentum::Nesterov(0.9)),
        weight_decay: None,
    });
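    // Note: SGD updates only the online q_net; target_q_net is cloned once and
    // never refreshed in this toy example, whereas a full DQN would re-sync it
    // from q_net every N steps.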
    // run through training data
    for _i_epoch in 0..15 {
        let start = Instant::now();

        // targ_q = R + discount * max(Q(S'))
        // curr_q = Q(S)[A]
        // loss = mse(curr_q, targ_q)
        let next_q_values: Tensor2D<64, ACTION_SIZE> = target_q_net.forward(next_state.clone());
        let max_next_q: Tensor1D<64> = next_q_values.max();
        let target_q = 0.99 * mul(max_next_q, 1.0 - done.clone()) + reward.clone();
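        // `1.0 - done` zeroes the bootstrap term for terminal transitions, so
        // targ_q collapses to just R when done == 1; 0.99 is the discount factor.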
        // forward through model, computing gradients
        let q_values = q_net.forward(state.trace());
        let action_qs: Tensor1D<64, OwnedTape> = q_values.select(&action);
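        // `.trace()` attaches a gradient tape to the input, and `.select(&action)`
        // gathers Q(s, a) for the action taken in each of the 64 batch rows.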
        let loss = mse_loss(action_qs, target_q);
        let loss_v = *loss.data();
        // run backprop
        let gradients = loss.backward();

        // update weights with optimizer
        sgd.update(&mut q_net, gradients).expect("Unused params");

        println!("q loss={:#.3} in {:?}", loss_v, start.elapsed());
    }
}
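
// The example trains on pre-chosen random actions, so it never shows how an
// agent would pick actions at rollout time. Below is a minimal sketch (not
// part of the original dfdx example) of epsilon-greedy selection over one
// state's Q-values, assuming they have been copied out into a plain array.
#[allow(dead_code)]
fn epsilon_greedy<R: Rng>(rng: &mut R, q_row: &[f32; ACTION_SIZE], epsilon: f32) -> usize {
    if rng.gen::<f32>() < epsilon {
        // explore: uniform random action
        rng.gen_range(0..ACTION_SIZE)
    } else {
        // exploit: argmax over the action dimension
        q_row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap()
    }
}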