Commit 858443e

Add logistic regression & optimize the gradient descent algorithm (#832)
1 parent 5afbec4 commit 858443e

4 files changed: +96 -1 lines changed

DIRECTORY.md (+1)

@@ -156,6 +156,7 @@
   * [Cholesky](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/cholesky.rs)
   * [K Means](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_means.rs)
   * [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs)
+  * [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs)
   * Loss Function
     * [Average Margin Ranking Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/average_margin_ranking_loss.rs)
     * [Hinge Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/hinge_loss.rs)
src/machine_learning/logistic_regression.rs (+92)

@@ -0,0 +1,92 @@
+use super::optimization::gradient_descent;
+use std::f64::consts::E;
+
+/// Returns the weights after performing logistic regression on the input data points.
+pub fn logistic_regression(
+    data_points: Vec<(Vec<f64>, f64)>,
+    iterations: usize,
+    learning_rate: f64,
+) -> Option<Vec<f64>> {
+    if data_points.is_empty() {
+        return None;
+    }
+
+    let num_features = data_points[0].0.len() + 1;
+    let mut params = vec![0.0; num_features];
+
+    let derivative_fn = |params: &[f64]| derivative(params, &data_points);
+
+    gradient_descent(derivative_fn, &mut params, learning_rate, iterations as i32);
+
+    Some(params)
+}
+
+fn derivative(params: &[f64], data_points: &[(Vec<f64>, f64)]) -> Vec<f64> {
+    let num_features = params.len();
+    let mut gradients = vec![0.0; num_features];
+
+    for (features, y_i) in data_points {
+        let z = params[0]
+            + params[1..]
+                .iter()
+                .zip(features)
+                .map(|(p, x)| p * x)
+                .sum::<f64>();
+        let prediction = 1.0 / (1.0 + E.powf(-z));
+
+        gradients[0] += prediction - y_i;
+        for (i, x_i) in features.iter().enumerate() {
+            gradients[i + 1] += (prediction - y_i) * x_i;
+        }
+    }
+
+    gradients
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_logistic_regression_simple() {
+        let data = vec![
+            (vec![0.0], 0.0),
+            (vec![1.0], 0.0),
+            (vec![2.0], 0.0),
+            (vec![3.0], 1.0),
+            (vec![4.0], 1.0),
+            (vec![5.0], 1.0),
+        ];
+
+        let result = logistic_regression(data, 10000, 0.05);
+        assert!(result.is_some());
+
+        let params = result.unwrap();
+        assert!((params[0] + 17.65).abs() < 1.0);
+        assert!((params[1] - 7.13).abs() < 1.0);
+    }
+
+    #[test]
+    fn test_logistic_regression_extreme_data() {
+        let data = vec![
+            (vec![-100.0], 0.0),
+            (vec![-10.0], 0.0),
+            (vec![0.0], 0.0),
+            (vec![10.0], 1.0),
+            (vec![100.0], 1.0),
+        ];
+
+        let result = logistic_regression(data, 10000, 0.05);
+        assert!(result.is_some());
+
+        let params = result.unwrap();
+        assert!((params[0] + 6.20).abs() < 1.0);
+        assert!((params[1] - 5.5).abs() < 1.0);
+    }
+
+    #[test]
+    fn test_logistic_regression_no_data() {
+        let result = logistic_regression(vec![], 5000, 0.1);
+        assert_eq!(result, None);
+    }
+}
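For reference, the `derivative` helper in this file accumulates the standard gradient of the summed binary cross-entropy (log-loss) for the sigmoid model. With intercept $\theta_0$, weights $\theta_j$, and data points $(x_i, y_i)$:

$$z_i = \theta_0 + \sum_j \theta_j x_{ij}, \qquad \sigma(z_i) = \frac{1}{1 + e^{-z_i}},$$

$$\frac{\partial J}{\partial \theta_0} = \sum_i \bigl(\sigma(z_i) - y_i\bigr), \qquad \frac{\partial J}{\partial \theta_j} = \sum_i \bigl(\sigma(z_i) - y_i\bigr)\, x_{ij},$$

which is exactly what the two accumulation loops build up in `gradients`.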

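A minimal calling sketch (a hypothetical driver, not part of the commit): the function takes `(features, label)` pairs, prepends an intercept term, and returns `[θ0, θ1, ...]`, so a prediction is recovered by applying the sigmoid to the learned linear score. It assumes `logistic_regression` is in scope via the `machine_learning` re-export added in the mod.rs diff that follows.

// Hypothetical driver, assuming `logistic_regression` is in scope.
fn main() {
    // Toy dataset: one feature, labels switch from 0 to 1 between x = 1 and x = 3.
    let data = vec![
        (vec![0.0], 0.0),
        (vec![1.0], 0.0),
        (vec![3.0], 1.0),
        (vec![4.0], 1.0),
    ];

    // params[0] is the intercept; params[1..] are the per-feature weights.
    let params = logistic_regression(data, 10_000, 0.05).expect("data is non-empty");

    // Recover P(y = 1 | x) by applying the sigmoid to the linear score.
    let x = 2.0;
    let z = params[0] + params[1] * x;
    let p = 1.0 / (1.0 + (-z).exp());
    println!("P(y = 1 | x = {x}) = {p:.3}");
}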
src/machine_learning/mod.rs (+2)

@@ -1,12 +1,14 @@
 mod cholesky;
 mod k_means;
 mod linear_regression;
+mod logistic_regression;
 mod loss_function;
 mod optimization;

 pub use self::cholesky::cholesky;
 pub use self::k_means::k_means;
 pub use self::linear_regression::linear_regression;
+pub use self::logistic_regression::logistic_regression;
 pub use self::loss_function::average_margin_ranking_loss;
 pub use self::loss_function::hng_loss;
 pub use self::loss_function::huber_loss;
src/machine_learning/optimization/gradient_descent.rs (+1 -1)

@@ -23,7 +23,7 @@
 /// A reference to the optimized parameter vector `x`.

 pub fn gradient_descent(
-    derivative_fn: fn(&[f64]) -> Vec<f64>,
+    derivative_fn: impl Fn(&[f64]) -> Vec<f64>,
     x: &mut Vec<f64>,
    learning_rate: f64,
     num_iterations: i32,
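This one-line change is the "optimize the gradient descent algorithm" half of the commit: `logistic_regression` builds `derivative_fn` as a closure that captures `data_points`, and a capturing closure cannot coerce to a plain `fn` pointer, so the old signature would not accept it. `impl Fn(&[f64]) -> Vec<f64>` accepts function pointers, non-capturing closures, and capturing closures alike. A standalone sketch of the distinction (illustrative names, not code from this repo):

// Old style: only a function pointer (or non-capturing closure) is accepted.
fn takes_fn_pointer(f: fn(&[f64]) -> Vec<f64>) -> Vec<f64> {
    f(&[1.0, 2.0])
}

// New style: any callable implementing `Fn`, including capturing closures.
fn takes_impl_fn(f: impl Fn(&[f64]) -> Vec<f64>) -> Vec<f64> {
    f(&[1.0, 2.0])
}

fn main() {
    // A non-capturing closure coerces to a `fn` pointer, so both calls work.
    let doubled = |xs: &[f64]| xs.iter().map(|x| 2.0 * x).collect::<Vec<f64>>();
    let _ = takes_fn_pointer(doubled);
    let _ = takes_impl_fn(doubled);

    // A capturing closure borrows `scale`, so only `impl Fn` accepts it.
    let scale = 3.0;
    let scaled = |xs: &[f64]| xs.iter().map(|x| scale * x).collect::<Vec<f64>>();
    // let _ = takes_fn_pointer(scaled); // error: expected fn pointer, found closure
    let _ = takes_impl_fn(scaled);
}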
