
Commit 2cd915d

Add Adam optimizer (#590)
1 parent 5cd1c18 commit 2cd915d

3 files changed: +292 −0 lines changed

src/machine_learning/mod.rs

+1
@@ -8,3 +8,4 @@ pub use self::loss_function::mae_loss;
 pub use self::loss_function::mse_loss;
 
 pub use self::optimization::gradient_descent;
+pub use self::optimization::Adam;
+288
@@ -0,0 +1,288 @@
//! # Adam (Adaptive Moment Estimation) optimizer
//!
//! The `Adam (Adaptive Moment Estimation)` optimizer is an adaptive learning rate algorithm used
//! in gradient-descent-based machine learning, for example when training neural networks on deep
//! learning problems. It is memory efficient, converges quickly, and maintains a separate learning
//! rate for each model parameter, updating each one iteratively based on the gradient history.
//!
//! ## Algorithm:
//!
//! Given:
//! - α is the learning rate
//! - (β_1, β_2) are the exponential decay rates for the moment estimates
//! - ϵ is a small value that prevents division by zero
//! - g_t are the gradients at time step t
//! - m_t are the biased first moment estimates of the gradient at time step t
//! - v_t are the biased second raw moment estimates of the gradient at time step t
//! - θ_t are the model parameters at time step t
//! - t is the time step
//!
//! Required:
//! θ_0
//!
//! Initialize:
//! m_0 <- 0
//! v_0 <- 0
//! t <- 0
//!
//! while θ_t not converged do
//!     t = t + 1
//!     m_t = β_1 * m_{t−1} + (1 − β_1) * g_t
//!     v_t = β_2 * v_{t−1} + (1 − β_2) * g_t^2
//!     m_hat_t = m_t / (1 − β_1^t)
//!     v_hat_t = v_t / (1 − β_2^t)
//!     θ_t = θ_{t−1} − α * m_hat_t / (sqrt(v_hat_t) + ϵ)
//!
//! ## Resources:
//! - Adam: A Method for Stochastic Optimization (by Diederik P. Kingma and Jimmy Ba):
//!   - <https://arxiv.org/abs/1412.6980>
//! - PyTorch Adam optimizer:
//!   - <https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam>
//!
pub struct Adam {
    learning_rate: f64, // alpha: initial step size for iterative optimization
    betas: (f64, f64),  // betas: exponential decay rates for moment estimates
    epsilon: f64,       // epsilon: prevent division by zero
    m: Vec<f64>,        // m: biased first moment estimate of the gradient vector
    v: Vec<f64>,        // v: biased second raw moment estimate of the gradient vector
    t: usize,           // t: time step
}

impl Adam {
    pub fn new(
        learning_rate: Option<f64>,
        betas: Option<(f64, f64)>,
        epsilon: Option<f64>,
        params_len: usize,
    ) -> Self {
        Adam {
            learning_rate: learning_rate.unwrap_or(1e-3), // typical good default lr
            betas: betas.unwrap_or((0.9, 0.999)),         // typical good default decay rates
            epsilon: epsilon.unwrap_or(1e-8),             // typical good default epsilon
            m: vec![0.0; params_len], // first moment vector elements all initialized to zero
            v: vec![0.0; params_len], // second moment vector elements all initialized to zero
            t: 0,                     // time step initialized to zero
        }
    }

    pub fn step(&mut self, gradients: &Vec<f64>) -> Vec<f64> {
        let mut model_params = vec![0.0; gradients.len()];
        self.t += 1;

        for i in 0..gradients.len() {
            // update biased first moment estimate and second raw moment estimate
            self.m[i] = self.betas.0 * self.m[i] + (1.0 - self.betas.0) * gradients[i];
            self.v[i] = self.betas.1 * self.v[i] + (1.0 - self.betas.1) * gradients[i].powf(2f64);

            // compute bias-corrected first moment estimate and second raw moment estimate
            let m_hat = self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
            let v_hat = self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));

            // update model parameters
            model_params[i] -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
        }
        model_params // return updated model parameters
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adam_init_default_values() {
        let optimizer = Adam::new(None, None, None, 1);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.m, vec![0.0; 1]);
        assert_eq!(optimizer.v, vec![0.0; 1]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_lr_value() {
        let optimizer = Adam::new(Some(0.9), None, None, 2);

        assert_eq!(optimizer.learning_rate, 0.9);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.m, vec![0.0; 2]);
        assert_eq!(optimizer.v, vec![0.0; 2]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_betas_value() {
        let optimizer = Adam::new(None, Some((0.8, 0.899)), None, 3);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.8, 0.899));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.m, vec![0.0; 3]);
        assert_eq!(optimizer.v, vec![0.0; 3]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_epsilon_value() {
        let optimizer = Adam::new(None, None, Some(1e-10), 4);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-10);
        assert_eq!(optimizer.m, vec![0.0; 4]);
        assert_eq!(optimizer.v, vec![0.0; 4]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_all_custom_values() {
        let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), 5);

        assert_eq!(optimizer.learning_rate, 1.0);
        assert_eq!(optimizer.betas, (0.001, 0.099));
        assert_eq!(optimizer.epsilon, 1e-1);
        assert_eq!(optimizer.m, vec![0.0; 5]);
        assert_eq!(optimizer.v, vec![0.0; 5]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_step_default_params() {
        let gradients = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];

        let mut optimizer = Adam::new(None, None, None, 8);
        let updated_params = optimizer.step(&gradients);

        assert_eq!(
            updated_params,
            vec![
                0.0009999999900000003,
                -0.000999999995,
                0.0009999999966666666,
                -0.0009999999975,
                0.000999999998,
                -0.0009999999983333334,
                0.0009999999985714286,
                -0.00099999999875
            ]
        );
    }

    #[test]
    fn test_adam_step_custom_params() {
        let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0];

        let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), 9);
        let updated_params = optimizer.step(&gradients);

        assert_eq!(
            updated_params,
            vec![
                -0.004999994444450618,
                0.004999993750007813,
                -0.004999992857153062,
                0.004999991666680556,
                -0.004999990000020001,
                0.004999987500031251,
                -0.004999983333388888,
                0.004999975000124999,
                -0.0049999500004999945
            ]
        );
    }

    #[test]
    fn test_adam_step_empty_gradients_array() {
        let gradients = vec![];

        let mut optimizer = Adam::new(None, None, None, 0);
        let updated_params = optimizer.step(&gradients);

        assert_eq!(updated_params, vec![]);
    }

    #[ignore]
    #[test]
    fn test_adam_step_iteratively_until_convergence_with_default_params() {
        const CONVERGENCE_THRESHOLD: f64 = 1e-5;
        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut optimizer = Adam::new(None, None, None, 6);

        let mut model_params = vec![0.0; 6];
        let mut updated_params = optimizer.step(&gradients);

        while (updated_params
            .iter()
            .zip(model_params.iter())
            .map(|(x, y)| x - y)
            .collect::<Vec<f64>>())
        .iter()
        .map(|&x| x.powi(2))
        .sum::<f64>()
        .sqrt()
            > CONVERGENCE_THRESHOLD
        {
            model_params = updated_params;
            updated_params = optimizer.step(&gradients);
        }

        assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 6]);
        assert_ne!(updated_params, model_params);
        assert_eq!(
            updated_params,
            vec![
                -0.0009999999899999931,
                -0.0009999999949999929,
                -0.0009999999966666597,
                -0.0009999999974999929,
                -0.0009999999979999927,
                -0.0009999999983333263
            ]
        );
    }

    #[ignore]
    #[test]
    fn test_adam_step_iteratively_until_convergence_with_custom_params() {
        const CONVERGENCE_THRESHOLD: f64 = 1e-7;
        let gradients = vec![7.0, -8.0, 9.0, -10.0, 11.0, -12.0, 13.0];

        let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), 7);

        let mut model_params = vec![0.0; 7];
        let mut updated_params = optimizer.step(&gradients);

        while (updated_params
            .iter()
            .zip(model_params.iter())
            .map(|(x, y)| x - y)
            .collect::<Vec<f64>>())
        .iter()
        .map(|&x| x.powi(2))
        .sum::<f64>()
        .sqrt()
            > CONVERGENCE_THRESHOLD
        {
            model_params = updated_params;
            updated_params = optimizer.step(&gradients);
        }

        assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 7]);
        assert_ne!(updated_params, model_params);
        assert_eq!(
            updated_params,
            vec![
                -0.004999992857153061,
                0.004999993750007814,
                -0.0049999944444506185,
                0.004999995000005001,
                -0.004999995454549587,
                0.004999995833336807,
                -0.004999996153849113
            ]
        );
    }
}
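
As a quick sanity check on the expected values in `test_adam_step_default_params` above, the first entry can be reproduced by hand from the documented update rule, using the defaults α = 0.001, β_1 = 0.9, β_2 = 0.999, ϵ = 1e-8 and the first gradient g_1 = -1.0 at time step t = 1:

m_1 = 0.9 * 0 + (1 - 0.9) * (-1.0) = -0.1
v_1 = 0.999 * 0 + (1 - 0.999) * (-1.0)^2 = 0.001
m_hat_1 = -0.1 / (1 - 0.9^1) = -1.0
v_hat_1 = 0.001 / (1 - 0.999^1) = 1.0
update = -0.001 * (-1.0) / (sqrt(1.0) + 1e-8) ≈ 0.00099999999

which matches the first expected element, 0.0009999999900000003 (the remaining digits come from floating-point rounding).
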
+3
@@ -1,3 +1,6 @@
 mod gradient_descent;
 
+mod adam;
+
+pub use self::adam::Adam;
 pub use self::gradient_descent::gradient_descent;
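
With the `Adam` re-export in place, a minimal usage sketch might look like the following. This is not part of the commit: `compute_gradients` is a hypothetical placeholder for whatever computes dL/dθ (here, the gradient of a toy quadratic loss), and it assumes `Adam` is in scope via the `pub use self::optimization::Adam;` re-export added above. Because `step` accumulates its result into a zero vector, its return value is the per-parameter update to add to the caller's own parameters.

// Hypothetical usage sketch, not part of this commit.
// Gradient of a toy loss L(θ) = Σ θ_i^2, i.e. dL/dθ_i = 2 * θ_i.
fn compute_gradients(params: &[f64]) -> Vec<f64> {
    params.iter().map(|p| 2.0 * p).collect()
}

fn main() {
    // Assumes `Adam` is in scope via the `machine_learning` re-export added in this commit.
    let mut params = vec![1.0, -2.0, 3.0];
    let mut optimizer = Adam::new(None, None, None, params.len());

    for _ in 0..1000 {
        let gradients = compute_gradients(&params);
        // `step` returns the per-parameter update (it starts from a zero
        // vector internally), so add it to the current parameters.
        let updates = optimizer.step(&gradients);
        for (p, u) in params.iter_mut().zip(updates.iter()) {
            *p += u;
        }
    }

    println!("params after 1000 Adam steps: {:?}", params);
}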
