Skip to content

Commit e467474

Browse files
committed
Fixing the test cases and minor adjustment to the algorithm
1 parent 21850fe commit e467474

File tree

1 file changed

+38
-12
lines changed

1 file changed

+38
-12
lines changed

src/machine_learning/logistic_regression.rs

+38-12
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub fn logistic_regression(
1111
return None;
1212
}
1313

14-
let num_features = data_points[0].0.len();
14+
let num_features = data_points[0].0.len() + 1;
1515
let mut params = vec![0.0; num_features];
1616

1717
let derivative_fn = |params: &[f64]| derivative(params, &data_points);
@@ -26,11 +26,12 @@ fn derivative(params: &[f64], data_points: &[(Vec<f64>, f64)]) -> Vec<f64> {
2626
let mut gradients = vec![0.0; num_features];
2727

2828
for (features, y_i) in data_points {
29-
let z = params.iter().zip(features).map(|(p, x)| p * x).sum::<f64>();
29+
let z = params[0] + params[1..].iter().zip(features).map(|(p, x)| p * x).sum::<f64>();
3030
let prediction = 1.0 / (1.0 + E.powf(-z));
3131

32+
gradients[0] += prediction - y_i;
3233
for (i, x_i) in features.iter().enumerate() {
33-
gradients[i] += (prediction - y_i) * x_i;
34+
gradients[i + 1] += (prediction - y_i) * x_i;
3435
}
3536
}
3637

@@ -42,21 +43,46 @@ mod test {
4243
use super::*;
4344

4445
#[test]
45-
fn test_logistic_regression() {
46+
fn test_logistic_regression_simple() {
4647
let data = vec![
47-
(vec![0.0, 0.0], 0.0),
48-
(vec![1.0, 1.0], 1.0),
49-
(vec![2.0, 2.0], 1.0),
48+
(vec![0.0], 0.0),
49+
(vec![1.0], 0.0),
50+
(vec![2.0], 0.0),
51+
(vec![3.0], 1.0),
52+
(vec![4.0], 1.0),
53+
(vec![5.0], 1.0),
5054
];
51-
let result = logistic_regression(data, 10000, 0.1);
55+
56+
let result = logistic_regression(data, 10000, 0.05);
57+
assert!(result.is_some());
58+
59+
let params = result.unwrap();
60+
assert!((params[0] + 17.65).abs() < 1.0);
61+
assert!((params[1] - 7.13).abs() < 1.0);
62+
}
63+
64+
#[test]
65+
fn test_logistic_regression_extreme_data() {
66+
let data = vec![
67+
(vec![-100.0], 0.0),
68+
(vec![-10.0], 0.0),
69+
(vec![0.0], 0.0),
70+
(vec![10.0], 1.0),
71+
(vec![100.0], 1.0),
72+
];
73+
74+
let result = logistic_regression(data, 10000, 0.05);
5275
assert!(result.is_some());
76+
5377
let params = result.unwrap();
54-
assert!((params[0] - 6.902976808251308).abs() < 1e-6);
55-
assert!((params[1] - 2000.4659358334482).abs() < 1e-6);
78+
assert!((params[0] + 6.20).abs() < 1.0);
79+
assert!((params[1] - 5.5).abs() < 1.0);
5680
}
5781

5882
#[test]
59-
fn test_empty_list_logistic_regression() {
60-
assert_eq!(logistic_regression(vec![], 10000, 0.1), None);
83+
fn test_logistic_regression_no_data() {
84+
let result = logistic_regression(vec![], 5000, 0.1);
85+
assert_eq!(result, None);
6186
}
6287
}
88+

0 commit comments

Comments
 (0)