//! # Adam (Adaptive Moment Estimation) optimizer
//!
//! Adam (Adaptive Moment Estimation) is an adaptive learning rate optimization algorithm for
//! gradient descent, widely used in machine learning, for example to train neural networks on
//! deep learning problems. It is memory-efficient and typically converges quickly: it maintains a
//! separate learning rate for each model parameter and updates it iteratively based on that
//! parameter's gradient history.
//!
//! ## Algorithm:
//!
//! Given:
//! - α is the learning rate
//! - (β_1, β_2) are the exponential decay rates for moment estimates
//! - ϵ is a small constant that prevents division by zero
//! - g_t are the gradients at time step t
//! - m_t are the biased first moment estimates of the gradient at time step t
//! - v_t are the biased second raw moment estimates of the gradient at time step t
//! - θ_t are the model parameters at time step t
//! - t is the time step
//!
//! Required:
//! θ_0
//!
//! Initialize:
//! m_0 <- 0
//! v_0 <- 0
//! t <- 0
//!
//! while θ_t not converged do
//! m_t = β_1 * m_{t−1} + (1 − β_1) * g_t
//! v_t = β_2 * v_{t−1} + (1 − β_2) * g_t^2
//! m_hat_t = m_t / (1 − β_1^t)
//! v_hat_t = v_t / (1 − β_2^t)
//! θ_t = θ_{t-1} − α * m_hat_t / (sqrt(v_hat_t) + ϵ)
//!
//! ## Resources:
//! - Adam: A Method for Stochastic Optimization (by Diederik P. Kingma and Jimmy Ba):
//!   - <https://arxiv.org/abs/1412.6980>
//! - PyTorch Adam optimizer:
//!   - <https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam>
//!
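//! ## Example:
//!
//! A minimal usage sketch (marked `ignore` because the import path for `Adam` depends on how the
//! crate exposes this module):
//!
//! ```ignore
//! // Create an optimizer with the default hyperparameters for a model with 3 parameters.
//! let mut optimizer = Adam::new(None, None, None, 3);
//!
//! // Apply one optimization step for a given gradient vector.
//! let gradients = vec![0.5, -1.0, 2.0];
//! let updates = optimizer.step(&gradients);
//! assert_eq!(updates.len(), 3);
//! ```
//!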
pub struct Adam {
    learning_rate: f64, // alpha: initial step size for iterative optimization
    betas: (f64, f64),  // betas: exponential decay rates for moment estimates
    epsilon: f64,       // epsilon: prevent division by zero
    m: Vec<f64>,        // m: biased first moment estimate of the gradient vector
    v: Vec<f64>,        // v: biased second raw moment estimate of the gradient vector
    t: usize,           // t: time step
}

impl Adam {
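    /// Builds an `Adam` optimizer for a model with `params_len` parameters.
    ///
    /// Any of `learning_rate`, `betas`, and `epsilon` may be `None`, in which case the commonly
    /// used defaults (1e-3, (0.9, 0.999), 1e-8) are substituted.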
    pub fn new(
        learning_rate: Option<f64>,
        betas: Option<(f64, f64)>,
        epsilon: Option<f64>,
        params_len: usize,
    ) -> Self {
        Adam {
            learning_rate: learning_rate.unwrap_or(1e-3), // typical good default lr
            betas: betas.unwrap_or((0.9, 0.999)),         // typical good default decay rates
            epsilon: epsilon.unwrap_or(1e-8),             // typical good default epsilon
            m: vec![0.0; params_len], // first moment vector elements all initialized to zero
            v: vec![0.0; params_len], // second moment vector elements all initialized to zero
            t: 0,                     // time step initialized to zero
        }
    }

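    /// Performs a single Adam update step for the given gradients and advances the time step.
    ///
    /// The returned vector holds, for each parameter, the result of applying the update
    /// -α * m_hat_t / (sqrt(v_hat_t) + ϵ) to a zero-initialized parameter; callers that track
    /// their own parameter values can add these updates to them.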
    pub fn step(&mut self, gradients: &[f64]) -> Vec<f64> {
        let mut model_params = vec![0.0; gradients.len()];
        self.t += 1;

        for i in 0..gradients.len() {
            // update biased first moment estimate and second raw moment estimate
            self.m[i] = self.betas.0 * self.m[i] + (1.0 - self.betas.0) * gradients[i];
            self.v[i] = self.betas.1 * self.v[i] + (1.0 - self.betas.1) * gradients[i].powi(2);

            // compute bias-corrected first moment estimate and second raw moment estimate
            let m_hat = self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
            let v_hat = self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));

            // update model parameters
            model_params[i] -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
        }
        model_params // return the updated model parameters (computed from zero-initialized values)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adam_init_default_values() {
        let optimizer = Adam::new(None, None, None, 1);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.m, vec![0.0; 1]);
        assert_eq!(optimizer.v, vec![0.0; 1]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_lr_value() {
        let optimizer = Adam::new(Some(0.9), None, None, 2);

        assert_eq!(optimizer.learning_rate, 0.9);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.m, vec![0.0; 2]);
        assert_eq!(optimizer.v, vec![0.0; 2]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_betas_value() {
        let optimizer = Adam::new(None, Some((0.8, 0.899)), None, 3);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.8, 0.899));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.m, vec![0.0; 3]);
        assert_eq!(optimizer.v, vec![0.0; 3]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_epsilon_value() {
        let optimizer = Adam::new(None, None, Some(1e-10), 4);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-10);
        assert_eq!(optimizer.m, vec![0.0; 4]);
        assert_eq!(optimizer.v, vec![0.0; 4]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_all_custom_values() {
        let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), 5);

        assert_eq!(optimizer.learning_rate, 1.0);
        assert_eq!(optimizer.betas, (0.001, 0.099));
        assert_eq!(optimizer.epsilon, 1e-1);
        assert_eq!(optimizer.m, vec![0.0; 5]);
        assert_eq!(optimizer.v, vec![0.0; 5]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_step_default_params() {
        let gradients = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];

        let mut optimizer = Adam::new(None, None, None, 8);
        let updated_params = optimizer.step(&gradients);

        assert_eq!(
            updated_params,
            vec![
                0.0009999999900000003,
                -0.000999999995,
                0.0009999999966666666,
                -0.0009999999975,
                0.000999999998,
                -0.0009999999983333334,
                0.0009999999985714286,
                -0.00099999999875
            ]
        );
    }

    #[test]
    fn test_adam_step_custom_params() {
        let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0];

        let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), 9);
        let updated_params = optimizer.step(&gradients);

        assert_eq!(
            updated_params,
            vec![
                -0.004999994444450618,
                0.004999993750007813,
                -0.004999992857153062,
                0.004999991666680556,
                -0.004999990000020001,
                0.004999987500031251,
                -0.004999983333388888,
                0.004999975000124999,
                -0.0049999500004999945
            ]
        );
    }

    #[test]
    fn test_adam_step_empty_gradients_array() {
        let gradients = vec![];

        let mut optimizer = Adam::new(None, None, None, 0);
        let updated_params = optimizer.step(&gradients);

        assert_eq!(updated_params, vec![]);
    }

    #[ignore]
    #[test]
    fn test_adam_step_iteratively_until_convergence_with_default_params() {
        const CONVERGENCE_THRESHOLD: f64 = 1e-5;
        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut optimizer = Adam::new(None, None, None, 6);

        let mut model_params = vec![0.0; 6];
        let mut updated_params = optimizer.step(&gradients);

        while updated_params
            .iter()
            .zip(model_params.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
            > CONVERGENCE_THRESHOLD
        {
            model_params = updated_params;
            updated_params = optimizer.step(&gradients);
        }

        assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 6]);
        assert_ne!(updated_params, model_params);
        assert_eq!(
            updated_params,
            vec![
                -0.0009999999899999931,
                -0.0009999999949999929,
                -0.0009999999966666597,
                -0.0009999999974999929,
                -0.0009999999979999927,
                -0.0009999999983333263
            ]
        );
    }

    #[ignore]
    #[test]
    fn test_adam_step_iteratively_until_convergence_with_custom_params() {
        const CONVERGENCE_THRESHOLD: f64 = 1e-7;
        let gradients = vec![7.0, -8.0, 9.0, -10.0, 11.0, -12.0, 13.0];

        let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), 7);

        let mut model_params = vec![0.0; 7];
        let mut updated_params = optimizer.step(&gradients);

        while updated_params
            .iter()
            .zip(model_params.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
            > CONVERGENCE_THRESHOLD
        {
            model_params = updated_params;
            updated_params = optimizer.step(&gradients);
        }

        assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 7]);
        assert_ne!(updated_params, model_params);
        assert_eq!(
            updated_params,
            vec![
                -0.004999992857153061,
                0.004999993750007814,
                -0.0049999944444506185,
                0.004999995000005001,
                -0.004999995454549587,
                0.004999995833336807,
                -0.004999996153849113
            ]
        );
    }
}