// 05-optim.rs (forked from coreylowman/dfdx)
//! Intro to dfdx::optim

use rand::prelude::*;
use dfdx::arrays::HasArrayData;
use dfdx::gradients::{Gradients, OwnedTape};
use dfdx::losses::mse_loss;
use dfdx::nn::{Linear, ModuleMut, ReLU, ResetParams, Tanh};
use dfdx::optim::{Momentum, Optimizer, Sgd, SgdConfig};
use dfdx::tensor::{Tensor2D, TensorCreator};

// first let's declare our neural network to optimize
type Mlp = (
    (Linear<5, 32>, ReLU),
    (Linear<32, 32>, ReLU),
    (Linear<32, 2>, Tanh),
);
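
// (tuples of modules compose sequentially in dfdx, so `Mlp` maps a 5-dim
// input through two ReLU hidden layers to a 2-dim Tanh output)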

fn main() {
    let mut rng = StdRng::seed_from_u64(0);

    // The first step to optimizing is to initialize the optimizer.
    // Here we construct a stochastic gradient descent optimizer
    // for our Mlp.
    let mut sgd: Sgd<Mlp> = Sgd::new(SgdConfig {
        lr: 1e-1,
        momentum: Some(Momentum::Nesterov(0.9)),
        weight_decay: None,
    });
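    // `lr` is the learning rate, `momentum` selects the momentum flavor
    // (classic momentum via `Some(Momentum::Classic(0.9))` should also work
    // here), and `weight_decay: None` leaves weight decay disabled.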

    // let's initialize our model and some dummy data
    let mut mlp: Mlp = Default::default();
    mlp.reset_params(&mut rng);
    let x: Tensor2D<3, 5> = TensorCreator::randn(&mut rng);
    let y: Tensor2D<3, 2> = TensorCreator::randn(&mut rng);
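    // (`Default::default()` builds the layers with default parameters;
    // `reset_params` then randomly re-initializes them using `rng`)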

    // first we pass our gradient tracing input through the network.
    // since we are training, we use forward_mut() instead of forward
    let prediction: Tensor2D<3, 2, OwnedTape> = mlp.forward_mut(x.trace());
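    // `x.trace()` copies `x` and attaches an `OwnedTape`, so every operation
    // from here on is recorded for backpropagation.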

    // next compute the loss against the target dummy data
    let loss = mse_loss(prediction, y.clone());
    dbg!(loss.data());

    // extract the gradients
    let gradients: Gradients = loss.backward();
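    // `backward()` consumes the loss (and its tape) and returns the gradients
    // accumulated for everything involved in computing it.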

    // the final step is to use our optimizer to update our model
    // given the gradients we've calculated.
    // This will modify our model!
    sgd.update(&mut mlp, gradients)
        .expect("Oops, there were some unused params");
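    // `update` returns an `Err` if any of the model's parameters had no
    // gradient recorded for them, which is why we `expect` here.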

    // let's do this a couple times to make sure the loss decreases!
    for i in 0..5 {
        let prediction = mlp.forward_mut(x.trace());
        let loss = mse_loss(prediction, y.clone());
        println!("Loss after update {i}: {:?}", loss.data());
        let gradients: Gradients = loss.backward();
        sgd.update(&mut mlp, gradients)
            .expect("Oops, there were some unused params");
    }
}