2 changes: 1 addition & 1 deletion candle-nn/src/lib.rs
@@ -53,7 +53,7 @@ pub use layer_norm::{
};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use ops::Dropout;
-pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use optim::{AdamW, Optimizer, ParamsAdamW, ParamsSGD, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use sequential::{seq, Sequential};
pub use var_builder::VarBuilder;
123 changes: 102 additions & 21 deletions candle-nn/src/optim.rs
@@ -28,54 +28,135 @@ pub trait Optimizer: Sized {
}
}

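/// Hyper-parameters for the [`SGD`] optimizer: the learning rate, an optional
/// classical momentum factor, and a flag selecting the Nesterov variant.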
#[derive(Clone, Debug)]
pub struct ParamsSGD {
pub lr: f64,
pub momentum: Option<f64>,
pub nesterov: bool,
}

impl Default for ParamsSGD {
fn default() -> Self {
Self {
lr: 0.01,
momentum: None,
nesterov: false,
}
}
}

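/// A trainable variable paired with its optional momentum (velocity) buffer.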
#[derive(Debug)]
struct VarSGD {
var: Var,
velocity: Option<Var>,
}

/// Optimizer for Stochastic Gradient Descent.
/// By default, the update rule of SGD is
/// ```tex
/// \theta_{t+1} = \theta_t - \eta \cdot g_t
/// ```
/// # Momentum
/// Momentum accumulates a moving average of past gradients to accelerate
/// updates in consistent directions and dampen oscillations.
///
-/// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
/// Momentum can be enabled by setting `momentum = Some(mu)`:
/// ```tex
/// v_{t+1} = \mu \cdot v_t + g_t
/// \theta_{t+1} = \theta_t - \eta \cdot v_{t+1}
/// ```
/// # Nesterov
/// Nesterov momentum improves upon standard momentum by computing the gradient
/// at a predicted future position rather than at the current position.
///
/// It can be enabled by setting `nesterov = true` (it only takes effect when
/// `momentum` is set):
/// ```tex
/// v_{t+1} = \mu \cdot v_t + g_t
/// \theta_{t+1} = \theta_t - \eta \cdot (g_t + \mu \cdot v_t)
/// ```
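/// # Example
/// A minimal usage sketch (not compiled here; the training loop and the scalar
/// `loss` tensor are assumed to be provided elsewhere):
/// ```ignore
/// use candle::{Device, Var};
/// use candle_nn::{Optimizer, ParamsSGD, SGD};
///
/// let w = Var::new(0f32, &Device::Cpu)?;
/// let params = ParamsSGD {
///     lr: 0.01,
///     momentum: Some(0.9),
///     nesterov: true,
/// };
/// let mut sgd = SGD::new(vec![w.clone()], params)?;
/// // Inside the training loop, after computing a scalar `loss`:
/// // sgd.backward_step(&loss)?;
/// ```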
#[derive(Debug)]
pub struct SGD {
-vars: Vec<Var>,
-learning_rate: f64,
vars: Vec<VarSGD>,
params: ParamsSGD,
}

impl Optimizer for SGD {
-type Config = f64;
type Config = ParamsSGD;

-fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
fn new(vars: Vec<Var>, params: ParamsSGD) -> Result<Self> {
let vars = vars
.into_iter()
.filter(|var| var.dtype().is_float())
-.collect();
-Ok(Self {
-vars,
-learning_rate,
-})
.map(|v| {
let velocity = params
.momentum
.map(|_| Var::zeros(v.shape(), v.dtype(), v.device()))
.transpose()?;
Ok(VarSGD { var: v, velocity })
})
.collect::<Result<Vec<_>>>()?;
Ok(Self { vars, params })
}

fn learning_rate(&self) -> f64 {
-self.learning_rate
self.params.lr
}

fn set_learning_rate(&mut self, lr: f64) {
self.params.lr = lr
}

fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
-for var in self.vars.iter() {
-if let Some(grad) = grads.get(var) {
-var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
let lr = self.params.lr;
for var in self.vars.iter_mut() {
let theta = &var.var;
if let Some(g) = grads.get(theta) {
match (&mut var.velocity, self.params.momentum) {
(None, None) => {
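// Plain SGD: theta <- theta - lr * g.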
let next = theta.sub(&(g * lr)?)?;
theta.set(&next)?;
}
(Some(v), Some(mu)) => {
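// Classical momentum: refresh the velocity buffer, v <- mu * v + g.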
let buf = v.as_tensor();
let next_v = ((buf * mu)? + g)?;
v.set(&next_v)?;

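// Nesterov steps along the raw gradient plus a momentum-scaled velocity term;
// classical momentum steps along the updated velocity.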
let update = if self.params.nesterov {
(g + buf * mu)?
} else {
next_v
};

let next_theta = theta.sub(&(update * lr)?)?;
theta.set(&next_theta)?;
}
_ => {
unreachable!("velocity and momentum must be consistent")
}
}
}
}
Ok(())
}

-fn set_learning_rate(&mut self, lr: f64) {
-self.learning_rate = lr
-}
}

impl SGD {
pub fn into_inner(self) -> Vec<Var> {
-self.vars
self.vars.into_iter().map(|v| v.var).collect()
}

-pub fn push(&mut self, var: &Var) {
-self.vars.push(var.clone())
pub fn push(&mut self, var: &Var) -> Result<()> {
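// Allocating the velocity buffer may fail, hence the Result return value
// (the previous `push` was infallible).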
let velocity = self
.params
.momentum
.map(|_| Var::zeros(var.shape(), var.dtype(), var.device()))
.transpose()?;

self.vars.push(VarSGD {
var: var.clone(),
velocity,
});
Ok(())
}
}

124 changes: 119 additions & 5 deletions candle-nn/tests/optim.rs
@@ -8,12 +8,16 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};

use anyhow::Result;
use candle::{DType, Device, Tensor, Var};
-use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, ParamsSGD, SGD};

#[test]
-fn sgd_optim() -> Result<()> {
fn sgd_optim_default() -> Result<()> {
let x = Var::new(0f32, &Device::Cpu)?;
-let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
let params = ParamsSGD {
lr: 0.1,
..Default::default()
};
let mut sgd = SGD::new(vec![x.clone()], params)?;
let xt = x.as_tensor();
for _step in 0..100 {
let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
@@ -59,7 +63,11 @@ fn sgd_linear_regression() -> Result<()> {
// Now use backprop to run a linear regression between samples and get the coefficients back.
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?;
-let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?;
let params = ParamsSGD {
lr: 0.004,
..Default::default()
};
let mut sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
for _step in 0..1000 {
let ys = lin.forward(&sample_xs)?;
@@ -71,6 +79,112 @@ fn sgd_linear_regression() -> Result<()> {
Ok(())
}

/* The results of this test have been checked against the following PyTorch code.
import torch
from torch import optim

w_gen = torch.tensor([[3., 1.]])
b_gen = torch.tensor([-2.])

sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
sample_ys = sample_xs.matmul(w_gen.t()) + b_gen

m = torch.nn.Linear(2, 1)
with torch.no_grad():
m.weight.zero_()
m.bias.zero_()
optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.3345)
for _step in range(100):
optimizer.zero_grad()
ys = m(sample_xs)
loss = ((ys - sample_ys)**2).sum()
loss.backward()
optimizer.step()
print(m.weight)
print(m.bias)
*/
#[test]
fn sgd_optim_momentum() -> Result<()> {
// Generate some linear data, y = 3.x1 + x2 - 2.
let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
let gen = Linear::new(w_gen, Some(b_gen));
let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
let sample_ys = gen.forward(&sample_xs)?;

// Now use backprop to run a linear regression between samples and get the coefficients back.
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?;
let params = ParamsSGD {
lr: 0.004,
momentum: Some(0.3345),
nesterov: false,
};
let mut sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
for _step in 0..100 {
let ys = lin.forward(&sample_xs)?;
let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
sgd.backward_step(&loss)?;
}
assert_eq!(w.to_vec2::<f32>()?, &[[2.9061165, 0.8827776]]);
assert_eq!(b.to_scalar::<f32>()?, -0.865149);
Ok(())
}

/* The results of this test have been checked against the following PyTorch code.
import torch
from torch import optim

w_gen = torch.tensor([[3., 1.]])
b_gen = torch.tensor([-2.])

sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
sample_ys = sample_xs.matmul(w_gen.t()) + b_gen

m = torch.nn.Linear(2, 1)
with torch.no_grad():
m.weight.zero_()
m.bias.zero_()
optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.833,nesterov=True)
for _step in range(100):
optimizer.zero_grad()
ys = m(sample_xs)
loss = ((ys - sample_ys)**2).sum()
loss.backward()
optimizer.step()
print(m.weight)
print(m.bias)
*/
#[test]
fn sgd_optim_momentum_nesterov() -> Result<()> {
// Generate some linear data, y = 3.x1 + x2 - 2.
let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
let gen = Linear::new(w_gen, Some(b_gen));
let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
let sample_ys = gen.forward(&sample_xs)?;

// Now use backprop to run a linear regression between samples and get the coefficients back.
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?;
let params = ParamsSGD {
lr: 0.004,
momentum: Some(0.833),
nesterov: true,
};
let mut sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
for _step in 0..100 {
let ys = lin.forward(&sample_xs)?;
let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
sgd.backward_step(&loss)?;
}
assert_eq!(w.to_vec2::<f32>()?, &[[-9.62991e27, -5.7198134e28]]);
assert_eq!(b.to_scalar::<f32>()?, -6.704839e27);
Ok(())
}

/* The following test returns the same values as the PyTorch code below.
import torch
from torch import optim
@@ -94,7 +208,7 @@ for _step in range(100):
optimizer.step()
print(m.weight)
print(m.bias)
-*/
*/
#[test]
fn adamw_linear_regression() -> Result<()> {
let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;