diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index c7a76fbd7a..c50004cef1 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -53,7 +53,7 @@ pub use layer_norm::{
 };
 pub use linear::{linear, linear_b, linear_no_bias, Linear};
 pub use ops::Dropout;
-pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
+pub use optim::{AdamW, Optimizer, ParamsAdamW, ParamsSGD, SGD};
 pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
 pub use sequential::{seq, Sequential};
 pub use var_builder::VarBuilder;
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 2c671fc59e..ff53b8eee8 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -28,54 +28,135 @@ pub trait Optimizer: Sized {
     }
 }
 
+#[derive(Clone, Debug)]
+pub struct ParamsSGD {
+    pub lr: f64,
+    pub momentum: Option<f64>,
+    pub nesterov: bool,
+}
+
+impl Default for ParamsSGD {
+    fn default() -> Self {
+        Self {
+            lr: 0.01,
+            momentum: None,
+            nesterov: false,
+        }
+    }
+}
+
+#[derive(Debug)]
+struct VarSGD {
+    var: Var,
+    velocity: Option<Var>,
+}
+
 /// Optimizer for Stochastic Gradient Descent.
+/// By default, the update rule of SGD is:
+/// ```text
+/// \theta_{t+1} = \theta_t - \eta \cdot g_t
+/// ```
+/// # Momentum
+/// Momentum accumulates a moving average of past gradients to accelerate
+/// updates in consistent directions and dampen oscillations.
 ///
-/// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
+/// You can enable momentum by setting `momentum = Some(mu)`:
+/// ```text
+/// v_t = \mu * v_{t-1} + g_t
+/// \theta_t = \theta_{t-1} - lr * v_t
+/// ```
+/// # Nesterov
+/// Nesterov momentum improves upon standard momentum by computing the gradient
+/// at a predicted future position rather than the current position.
+///
+/// You can enable the Nesterov variant by setting `nesterov = true`:
+/// ```text
+/// v_t = \mu * v_{t-1} + g_t
+/// \theta_t = \theta_{t-1} - lr * (g_t + \mu * v_t)
+/// ```
 #[derive(Debug)]
 pub struct SGD {
-    vars: Vec<Var>,
-    learning_rate: f64,
+    vars: Vec<VarSGD>,
+    params: ParamsSGD,
 }
 
 impl Optimizer for SGD {
-    type Config = f64;
+    type Config = ParamsSGD;
 
-    fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+    fn new(vars: Vec<Var>, params: ParamsSGD) -> Result<Self> {
         let vars = vars
             .into_iter()
             .filter(|var| var.dtype().is_float())
-            .collect();
-        Ok(Self {
-            vars,
-            learning_rate,
-        })
+            .map(|v| {
+                let velocity = params
+                    .momentum
+                    .map(|_| Var::zeros(v.shape(), v.dtype(), v.device()))
+                    .transpose()?;
+                Ok(VarSGD { var: v, velocity })
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self { vars, params })
     }
 
     fn learning_rate(&self) -> f64 {
-        self.learning_rate
+        self.params.lr
+    }
+
+    fn set_learning_rate(&mut self, lr: f64) {
+        self.params.lr = lr
     }
 
     fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
-        for var in self.vars.iter() {
-            if let Some(grad) = grads.get(var) {
-                var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
+        let lr = self.params.lr;
+        for var in self.vars.iter_mut() {
+            let theta = &var.var;
+            if let Some(g) = grads.get(theta) {
+                match (&mut var.velocity, self.params.momentum) {
+                    (None, None) => {
+                        let next = theta.sub(&(g * lr)?)?;
+                        theta.set(&next)?;
+                    }
+                    (Some(v), Some(mu)) => {
+                        let buf = v.as_tensor();
+                        let next_v = ((buf * mu)? + g)?;
+                        v.set(&next_v)?;
+
+                        let update = if self.params.nesterov {
+                            (g + buf * mu)?
+                        } else {
+                            next_v
+                        };
+
+                        let next_theta = theta.sub(&(update * lr)?)?;
+                        theta.set(&next_theta)?;
+                    }
+                    _ => {
+                        unreachable!("velocity and momentum must be consistent")
+                    }
+                }
             }
         }
         Ok(())
     }
-
-    fn set_learning_rate(&mut self, lr: f64) {
-        self.learning_rate = lr
-    }
 }
 
 impl SGD {
     pub fn into_inner(self) -> Vec<Var> {
-        self.vars
+        self.vars.into_iter().map(|v| v.var).collect()
    }
 
-    pub fn push(&mut self, var: &Var) {
-        self.vars.push(var.clone())
+    pub fn push(&mut self, var: &Var) -> Result<()> {
+        let velocity = self
+            .params
+            .momentum
+            .map(|_| Var::zeros(var.shape(), var.dtype(), var.device()))
+            .transpose()?;
+
+        self.vars.push(VarSGD {
+            var: var.clone(),
+            velocity,
+        });
+        Ok(())
     }
 }
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 4eb14ed81b..d29178bd37 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -8,12 +8,16 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};
 
 use anyhow::Result;
 use candle::{DType, Device, Tensor, Var};
-use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
+use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, ParamsSGD, SGD};
 
 #[test]
-fn sgd_optim() -> Result<()> {
+fn sgd_optim_default() -> Result<()> {
     let x = Var::new(0f32, &Device::Cpu)?;
-    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
+    let params = ParamsSGD {
+        lr: 0.1,
+        ..Default::default()
+    };
+    let mut sgd = SGD::new(vec![x.clone()], params)?;
     let xt = x.as_tensor();
     for _step in 0..100 {
         let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
@@ -59,7 +63,11 @@ fn sgd_linear_regression() -> Result<()> {
     // Now use backprop to run a linear regression between samples and get the coefficients back.
     let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
     let b = Var::new(0f32, &Device::Cpu)?;
-    let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?;
+    let params = ParamsSGD {
+        lr: 0.004,
+        ..Default::default()
+    };
+    let mut sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
     let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
     for _step in 0..1000 {
         let ys = lin.forward(&sample_xs)?;
@@ -71,6 +79,112 @@ fn sgd_linear_regression() -> Result<()> {
     Ok(())
 }
 
+/* The results of this test have been checked against the following PyTorch code.
+    import torch
+    from torch import optim
+
+    w_gen = torch.tensor([[3., 1.]])
+    b_gen = torch.tensor([-2.])
+
+    sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
+    sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
+
+    m = torch.nn.Linear(2, 1)
+    with torch.no_grad():
+        m.weight.zero_()
+        m.bias.zero_()
+    optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.3345)
+    for _step in range(100):
+        optimizer.zero_grad()
+        ys = m(sample_xs)
+        loss = ((ys - sample_ys)**2).sum()
+        loss.backward()
+        optimizer.step()
+    print(m.weight)
+    print(m.bias)
+*/
+#[test]
+fn sgd_optim_momentum() -> Result<()> {
+    // Generate some linear data, y = 3.x1 + x2 - 2.
+    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+    let gen = Linear::new(w_gen, Some(b_gen));
+    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+    let sample_ys = gen.forward(&sample_xs)?;
+
+    // Now use backprop to run a linear regression between samples and get the coefficients back.
+    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
+    let b = Var::new(0f32, &Device::Cpu)?;
+    let params = ParamsSGD {
+        lr: 0.004,
+        momentum: Some(0.3345),
+        nesterov: false,
+    };
+    let mut sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
+    let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
+    for _step in 0..100 {
+        let ys = lin.forward(&sample_xs)?;
+        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+        sgd.backward_step(&loss)?;
+    }
+    assert_eq!(w.to_vec2::<f32>()?, &[[2.9061165, 0.8827776]]);
+    assert_eq!(b.to_scalar::<f32>()?, -0.865149);
+    Ok(())
+}
+
+/* The results of this test have been checked against the following PyTorch code.
+    import torch
+    from torch import optim
+
+    w_gen = torch.tensor([[3., 1.]])
+    b_gen = torch.tensor([-2.])
+
+    sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
+    sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
+
+    m = torch.nn.Linear(2, 1)
+    with torch.no_grad():
+        m.weight.zero_()
+        m.bias.zero_()
+    optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.833, nesterov=True)
+    for _step in range(100):
+        optimizer.zero_grad()
+        ys = m(sample_xs)
+        loss = ((ys - sample_ys)**2).sum()
+        loss.backward()
+        optimizer.step()
+    print(m.weight)
+    print(m.bias)
+*/
+#[test]
+fn sgd_optim_momentum_nesterov() -> Result<()> {
+    // Generate some linear data, y = 3.x1 + x2 - 2.
+    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+    let gen = Linear::new(w_gen, Some(b_gen));
+    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+    let sample_ys = gen.forward(&sample_xs)?;
+
+    // Now use backprop to run a linear regression between samples and get the coefficients back.
+    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
+    let b = Var::new(0f32, &Device::Cpu)?;
+    let params = ParamsSGD {
+        lr: 0.004,
+        momentum: Some(0.833),
+        nesterov: true,
+    };
+    let mut sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
+    let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
+    for _step in 0..100 {
+        let ys = lin.forward(&sample_xs)?;
+        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+        sgd.backward_step(&loss)?;
+    }
+    assert_eq!(w.to_vec2::<f32>()?, &[[-9.62991e27, -5.7198134e28]]);
+    assert_eq!(b.to_scalar::<f32>()?, -6.704839e27);
+    Ok(())
+}
+
 /* The following test returns the same values as the PyTorch code below.
 import torch
 from torch import optim
@@ -94,7 +208,7 @@ for _step in range(100):
     optimizer.step()
 print(m.weight)
 print(m.bias)
-*/
+ */
 #[test]
 fn adamw_linear_regression() -> Result<()> {
     let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
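Usage note (not part of the patch): a minimal sketch of how the new `ParamsSGD` options are meant to be used, assuming the `candle`/`candle_nn` crate names from the tests above; the hyper-parameter values are illustrative only.

```rust
use candle::{Device, Var};
use candle_nn::{Optimizer, ParamsSGD, SGD};

fn main() -> candle::Result<()> {
    // Single scalar parameter, optimized towards 4.2 as in the existing tests.
    let x = Var::new(0f32, &Device::Cpu)?;
    let params = ParamsSGD {
        lr: 0.1,
        momentum: Some(0.9), // v_t = mu * v_{t-1} + g_t
        nesterov: true,      // update uses g_t + mu * v_t instead of v_t
    };
    let mut sgd = SGD::new(vec![x.clone()], params)?;
    for _step in 0..100 {
        let xt = x.as_tensor();
        let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
        // Backpropagates the loss and applies the (Nesterov) momentum update in place.
        sgd.backward_step(&loss)?;
    }
    println!("x = {}", x.as_tensor().to_scalar::<f32>()?);
    Ok(())
}
```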