@@ -11,7 +11,7 @@ pub fn logistic_regression(
11
11
return None ;
12
12
}
13
13
14
- let num_features = data_points[ 0 ] . 0 . len ( ) ;
14
+ let num_features = data_points[ 0 ] . 0 . len ( ) + 1 ;
15
15
let mut params = vec ! [ 0.0 ; num_features] ;
16
16
17
17
let derivative_fn = |params : & [ f64 ] | derivative ( params, & data_points) ;
@@ -26,11 +26,12 @@ fn derivative(params: &[f64], data_points: &[(Vec<f64>, f64)]) -> Vec<f64> {
26
26
let mut gradients = vec ! [ 0.0 ; num_features] ;
27
27
28
28
for ( features, y_i) in data_points {
29
- let z = params. iter ( ) . zip ( features) . map ( |( p, x) | p * x) . sum :: < f64 > ( ) ;
29
+ let z = params[ 0 ] + params [ 1 .. ] . iter ( ) . zip ( features) . map ( |( p, x) | p * x) . sum :: < f64 > ( ) ;
30
30
let prediction = 1.0 / ( 1.0 + E . powf ( -z) ) ;
31
31
32
+ gradients[ 0 ] += prediction - y_i;
32
33
for ( i, x_i) in features. iter ( ) . enumerate ( ) {
33
- gradients[ i] += ( prediction - y_i) * x_i;
34
+ gradients[ i + 1 ] += ( prediction - y_i) * x_i;
34
35
}
35
36
}
36
37
@@ -42,21 +43,46 @@ mod test {
42
43
use super :: * ;
43
44
44
45
#[ test]
45
- fn test_logistic_regression ( ) {
46
+ fn test_logistic_regression_simple ( ) {
46
47
let data = vec ! [
47
- ( vec![ 0.0 , 0.0 ] , 0.0 ) ,
48
- ( vec![ 1.0 , 1.0 ] , 1.0 ) ,
49
- ( vec![ 2.0 , 2.0 ] , 1.0 ) ,
48
+ ( vec![ 0.0 ] , 0.0 ) ,
49
+ ( vec![ 1.0 ] , 0.0 ) ,
50
+ ( vec![ 2.0 ] , 0.0 ) ,
51
+ ( vec![ 3.0 ] , 1.0 ) ,
52
+ ( vec![ 4.0 ] , 1.0 ) ,
53
+ ( vec![ 5.0 ] , 1.0 ) ,
50
54
] ;
51
- let result = logistic_regression ( data, 10000 , 0.1 ) ;
55
+
56
+ let result = logistic_regression ( data, 10000 , 0.05 ) ;
57
+ assert ! ( result. is_some( ) ) ;
58
+
59
+ let params = result. unwrap ( ) ;
60
+ assert ! ( ( params[ 0 ] + 17.65 ) . abs( ) < 1.0 ) ;
61
+ assert ! ( ( params[ 1 ] - 7.13 ) . abs( ) < 1.0 ) ;
62
+ }
63
+
64
+ #[ test]
65
+ fn test_logistic_regression_extreme_data ( ) {
66
+ let data = vec ! [
67
+ ( vec![ -100.0 ] , 0.0 ) ,
68
+ ( vec![ -10.0 ] , 0.0 ) ,
69
+ ( vec![ 0.0 ] , 0.0 ) ,
70
+ ( vec![ 10.0 ] , 1.0 ) ,
71
+ ( vec![ 100.0 ] , 1.0 ) ,
72
+ ] ;
73
+
74
+ let result = logistic_regression ( data, 10000 , 0.05 ) ;
52
75
assert ! ( result. is_some( ) ) ;
76
+
53
77
let params = result. unwrap ( ) ;
54
- assert ! ( ( params[ 0 ] - 6.902976808251308 ) . abs( ) < 1e-6 ) ;
55
- assert ! ( ( params[ 1 ] - 2000.4659358334482 ) . abs( ) < 1e-6 ) ;
78
+ assert ! ( ( params[ 0 ] + 6.20 ) . abs( ) < 1.0 ) ;
79
+ assert ! ( ( params[ 1 ] - 5.5 ) . abs( ) < 1.0 ) ;
56
80
}
57
81
58
82
#[ test]
59
- fn test_empty_list_logistic_regression ( ) {
60
- assert_eq ! ( logistic_regression( vec![ ] , 10000 , 0.1 ) , None ) ;
83
+ fn test_logistic_regression_no_data ( ) {
84
+ let result = logistic_regression ( vec ! [ ] , 5000 , 0.1 ) ;
85
+ assert_eq ! ( result, None ) ;
61
86
}
62
87
}
88
+
0 commit comments