
Commit ecafde6

Add gradient_descent algorithm (#580)
1 parent 5a5c0d4 commit ecafde6

File tree

3 files changed: +93 -1 lines changed


src/machine_learning/mod.rs

+4-1
@@ -1,2 +1,5 @@
 mod linear_regression;
-pub use linear_regression::linear_regression;
+mod optimization;
+
+pub use self::linear_regression::linear_regression;
+pub use self::optimization::gradient_descent;

src/machine_learning/optimization/gradient_descent.rs

+86
@@ -0,0 +1,86 @@
/// Gradient Descent Optimization
///
/// Gradient descent is an iterative optimization algorithm used to find the minimum of a function.
/// It works by updating the parameters (in this case, elements of the vector `x`) in the direction of
/// the steepest decrease in the function's value. This is achieved by subtracting the gradient of
/// the function at the current point from the current point. The learning rate controls the step size.
///
/// The equation for a single parameter (univariate) is:
/// x_{k+1} = x_k - learning_rate * derivative_of_function(x_k)
///
/// For multivariate functions, it extends to each parameter:
/// x_{k+1} = x_k - learning_rate * gradient_of_function(x_k)
///
/// # Arguments
///
/// * `derivative_fn` - The function that calculates the gradient of the objective function at a given point.
/// * `x` - The initial parameter vector to be optimized.
/// * `learning_rate` - Step size for each iteration.
/// * `num_iterations` - The number of iterations to run the optimization.
///
/// # Returns
///
/// A reference to the optimized parameter vector `x`.

pub fn gradient_descent(
    derivative_fn: fn(&[f64]) -> Vec<f64>,
    x: &mut Vec<f64>,
    learning_rate: f64,
    num_iterations: i32,
) -> &mut Vec<f64> {
    for _ in 0..num_iterations {
        let gradient = derivative_fn(x);
        for (x_k, grad) in x.iter_mut().zip(gradient.iter()) {
            *x_k -= learning_rate * grad;
        }
    }

    x
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_gradient_descent_optimized() {
        fn derivative_of_square(params: &[f64]) -> Vec<f64> {
            params.iter().map(|x| 2. * x).collect()
        }

        let mut x: Vec<f64> = vec![5.0, 6.0];
        let learning_rate: f64 = 0.03;
        let num_iterations: i32 = 1000;

        let minimized_vector =
            gradient_descent(derivative_of_square, &mut x, learning_rate, num_iterations);

        let test_vector = [0.0, 0.0];

        let tolerance = 1e-6;
        for (minimized_value, test_value) in minimized_vector.iter().zip(test_vector.iter()) {
            assert!((minimized_value - test_value).abs() < tolerance);
        }
    }

    #[test]
    fn test_gradient_descent_unoptimized() {
        fn derivative_of_square(params: &[f64]) -> Vec<f64> {
            params.iter().map(|x| 2. * x).collect()
        }

        let mut x: Vec<f64> = vec![5.0, 6.0];
        let learning_rate: f64 = 0.03;
        let num_iterations: i32 = 10;

        let minimized_vector =
            gradient_descent(derivative_of_square, &mut x, learning_rate, num_iterations);

        let test_vector = [0.0, 0.0];

        let tolerance = 1e-6;
        for (minimized_value, test_value) in minimized_vector.iter().zip(test_vector.iter()) {
            assert!((minimized_value - test_value).abs() >= tolerance);
        }
    }
}
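
A quick worked pass over the doc comment's update rule, using the same setup as the tests above (derivative_fn returning 2x for f(x) = x^2, learning_rate = 0.03, starting value 5.0): each iteration computes x_{k+1} = x_k - 0.03 * 2x_k = 0.94 * x_k, so x_1 = 4.7, x_2 = 4.418, and after k iterations x_k = 5 * 0.94^k. After 10 iterations this is roughly 2.69, still far outside the 1e-6 tolerance (which is what test_gradient_descent_unoptimized asserts), while after 1000 iterations it is on the order of 1e-26, comfortably inside it (test_gradient_descent_optimized).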
src/machine_learning/optimization/mod.rs

+3
@@ -0,0 +1,3 @@
mod gradient_descent;

pub use self::gradient_descent::gradient_descent;
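
A minimal usage sketch, not part of the commit: it assumes `gradient_descent` is in scope through the `machine_learning` re-export added above, and it minimizes a hypothetical objective f(x, y) = (x - 3)^2 + (y + 1)^2 with gradient (2(x - 3), 2(y + 1)); the helper name `gradient_of_shifted_square`, step size, and iteration count are illustrative choices.

// Usage sketch; assumes `gradient_descent` is in scope via the
// `machine_learning` re-export added in this commit.
fn gradient_of_shifted_square(params: &[f64]) -> Vec<f64> {
    // Gradient of f(x, y) = (x - 3)^2 + (y + 1)^2, which is minimized at (3, -1).
    vec![2.0 * (params[0] - 3.0), 2.0 * (params[1] + 1.0)]
}

fn main() {
    let mut x = vec![0.0, 0.0];
    let result = gradient_descent(gradient_of_shifted_square, &mut x, 0.05, 500);

    // With a 0.05 step the error shrinks by a factor of 0.9 per iteration,
    // so 500 iterations land well within a 1e-6 tolerance of (3, -1).
    assert!((result[0] - 3.0).abs() < 1e-6);
    assert!((result[1] + 1.0).abs() < 1e-6);
    println!("{result:?}");
}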

0 commit comments
