@@ -77,8 +77,8 @@ pub fn fit_nu<F: Float>(
7777 dataset : ArrayView2 < F > ,
7878 kernel : Kernel < F > ,
7979 target : & [ F ] ,
80- c : F ,
8180 nu : F ,
81+ c : F ,
8282) -> Svm < F , F > {
8383 let mut alpha = vec ! [ F :: zero( ) ; 2 * target. len( ) ] ;
8484 let mut linear_term = vec ! [ F :: zero( ) ; 2 * target. len( ) ] ;
@@ -128,21 +128,21 @@ macro_rules! impl_regression {
128128 let target = target. as_slice( ) . unwrap( ) ;
129129
130130 let ret = match ( self . c( ) , self . nu( ) ) {
131- ( Some ( ( c, eps ) ) , _) => fit_epsilon(
131+ ( Some ( ( c, p ) ) , _) => fit_epsilon(
132132 self . solver_params( ) . clone( ) ,
133133 dataset. records( ) . view( ) ,
134134 kernel,
135135 target,
136136 c,
137- eps ,
137+ p ,
138138 ) ,
139- ( None , Some ( ( nu, eps ) ) ) => fit_nu(
139+ ( None , Some ( ( nu, c ) ) ) => fit_nu(
140140 self . solver_params( ) . clone( ) ,
141141 dataset. records( ) . view( ) ,
142142 kernel,
143143 target,
144144 nu,
145- eps ,
145+ c ,
146146 ) ,
147147 _ => panic!( "Set either C value or Nu value" ) ,
148148 } ;
@@ -206,73 +206,95 @@ pub mod tests {
206206 use linfa:: dataset:: Dataset ;
207207 use linfa:: metrics:: SingleTargetRegression ;
208208 use linfa:: traits:: { Fit , Predict } ;
209- use ndarray:: Array ;
210-
211- #[ test]
212- fn test_linear_epsilon_regression ( ) -> Result < ( ) > {
213- let target = Array :: linspace ( 0f64 , 10. , 100 ) ;
214- let mut sin_curve = Array :: zeros ( ( 100 , 1 ) ) ;
215- for ( i, val) in target. iter ( ) . enumerate ( ) {
216- sin_curve[ ( i, 0 ) ] = * val;
217- }
218-
219- let dataset = Dataset :: new ( sin_curve, target) ;
220-
221- let model = Svm :: params ( )
222- . nu_eps ( 2. , 0.01 )
223- . gaussian_kernel ( 50. )
224- . fit ( & dataset) ?;
209+ use linfa:: DatasetBase ;
210+ use ndarray:: { Array , Array1 , Array2 } ;
225211
212+ fn _check_model ( model : Svm < f64 , f64 > , dataset : & DatasetBase < Array2 < f64 > , Array1 < f64 > > ) {
226213 println ! ( "{}" , model) ;
227-
228214 let predicted = model. predict ( dataset. records ( ) ) ;
215+ let err = predicted. mean_squared_error ( & dataset) . unwrap ( ) ;
216+ println ! ( "err={}" , err) ;
229217 assert ! ( predicted. mean_squared_error( & dataset) . unwrap( ) < 1e-2 ) ;
230-
231- Ok ( ( ) )
232218 }
233219
234220 #[ test]
235- fn test_linear_nu_regression ( ) -> Result < ( ) > {
236- let target = Array :: linspace ( 0f64 , 10. , 100 ) ;
237- let mut sin_curve = Array :: zeros ( ( 100 , 1 ) ) ;
238- for ( i, val) in target. iter ( ) . enumerate ( ) {
239- sin_curve[ ( i, 0 ) ] = * val;
240- }
241-
242- let dataset = Dataset :: new ( sin_curve, target) ;
221+ fn test_epsilon_regression_linear ( ) -> Result < ( ) > {
222+ // simple 2d straight line
223+ let targets = Array :: linspace ( 0f64 , 10. , 100 ) ;
224+ let records = targets. clone ( ) . into_shape ( ( 100 , 1 ) ) . unwrap ( ) ;
225+ let dataset = Dataset :: new ( records, targets) ;
243226
244227 let model = Svm :: params ( )
245- . nu_eps ( 2 ., 0.01 )
246- . gaussian_kernel ( 50. )
228+ . c_svr ( 5 ., None )
229+ . linear_kernel ( )
247230 . fit ( & dataset) ?;
231+ _check_model ( model, & dataset) ;
248232
249- println ! ( "{}" , model) ;
250-
251- let predicted = model. predict ( & dataset) ;
252- assert ! ( predicted. mean_squared_error( & dataset) . unwrap( ) < 1e-2 ) ;
233+ // Old API
234+ #[ allow( deprecated) ]
235+ let model2 = Svm :: params ( )
236+ . c_eps ( 5. , 1e-3 )
237+ . linear_kernel ( )
238+ . fit ( & dataset) ?;
239+ _check_model ( model2, & dataset) ;
253240
254241 Ok ( ( ) )
255242 }
256243
257244 #[ test]
258- fn test_regression_linear_kernel ( ) -> Result < ( ) > {
245+ fn test_nu_regression_linear ( ) -> Result < ( ) > {
259246 // simple 2d straight line
260247 let targets = Array :: linspace ( 0f64 , 10. , 100 ) ;
261248 let records = targets. clone ( ) . into_shape ( ( 100 , 1 ) ) . unwrap ( ) ;
262-
263249 let dataset = Dataset :: new ( records, targets) ;
264250
265251 // Test the precomputed dot product in the linear kernel case
266252 let model = Svm :: params ( )
267- . nu_eps ( 2. , 0.01 )
253+ . nu_svr ( 0.5 , Some ( 1. ) )
268254 . linear_kernel ( )
269255 . fit ( & dataset) ?;
256+ _check_model ( model, & dataset) ;
270257
271- println ! ( "{}" , model) ;
258+ // Old API
259+ #[ allow( deprecated) ]
260+ let model2 = Svm :: params ( )
261+ . nu_eps ( 0.5 , 1e-3 )
262+ . linear_kernel ( )
263+ . fit ( & dataset) ?;
264+ _check_model ( model2, & dataset) ;
265+ Ok ( ( ) )
266+ }
272267
273- let predicted = model. predict ( & dataset) ;
274- assert ! ( predicted. mean_squared_error( & dataset) . unwrap( ) < 1e-2 ) ;
268+ #[ test]
269+ fn test_epsilon_regression_gaussian ( ) -> Result < ( ) > {
270+ let records = Array :: linspace ( 0f64 , 10. , 100 )
271+ . into_shape ( ( 100 , 1 ) )
272+ . unwrap ( ) ;
273+ let sin_curve = records. mapv ( |v| v. sin ( ) ) . into_shape ( ( 100 , ) ) . unwrap ( ) ;
274+ let dataset = Dataset :: new ( records, sin_curve) ;
275+
276+ let model = Svm :: params ( )
277+ . c_svr ( 100. , Some ( 0.1 ) )
278+ . gaussian_kernel ( 10. )
279+ . eps ( 1e-3 )
280+ . fit ( & dataset) ?;
281+ _check_model ( model, & dataset) ;
282+ Ok ( ( ) )
283+ }
284+
285+ #[ test]
286+ fn test_nu_regression_polynomial ( ) -> Result < ( ) > {
287+ let n = 100 ;
288+ let records = Array :: linspace ( 0f64 , 5. , n) . into_shape ( ( n, 1 ) ) . unwrap ( ) ;
289+ let sin_curve = records. mapv ( |v| v. sin ( ) ) . into_shape ( ( n, ) ) . unwrap ( ) ;
290+ let dataset = Dataset :: new ( records, sin_curve) ;
275291
292+ let model = Svm :: params ( )
293+ . nu_svr ( 0.01 , None )
294+ . polynomial_kernel ( 1. , 3. )
295+ . eps ( 1e-3 )
296+ . fit ( & dataset) ?;
297+ _check_model ( model, & dataset) ;
276298 Ok ( ( ) )
277299 }
278300}
0 commit comments