Skip to content

Commit a30e5f1

Browse files
authored
[linfa-svm] Fix SVR nu parameter passing and rework SVR parameterization API (#370)
* New API for SVR parameterization * Fix nu parameter passing * Add SVR example * Upload code coverage on pull request only * Bump linfa-svm version to 0.7.2 * Try to compute code coverage only on PR on master * Add SVR test with polynomial kernel * Fix deprecated functions (as it should have been wired) * Test rewired deprecated API
1 parent 936680a commit a30e5f1

File tree

5 files changed

+140
-54
lines changed

5 files changed

+140
-54
lines changed

.github/workflows/codequality.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ jobs:
3030
run: cargo clippy --all-targets -- -D warnings
3131

3232
coverage:
33+
needs: codequality
3334
name: coverage
3435
runs-on: ubuntu-latest
35-
if: github.event.pull_request.draft == false
36+
if: github.event.pull_request.draft == false && (github.event_name == 'pull_request' || github.ref == 'refs/heads/master')
3637

3738
steps:
3839
- name: Checkout sources
@@ -65,4 +66,4 @@ jobs:
6566
with:
6667
token: ${{ secrets.CODECOV_TOKEN }}
6768
fail_ci_if_error: true
68-
verbose: true
69+

algorithms/linfa-svm/Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "linfa-svm"
3-
version = "0.7.1"
3+
version = "0.7.2"
44
edition = "2018"
55
authors = ["Lorenz Schmidt <[email protected]>"]
66
description = "Support Vector Machines"
@@ -33,6 +33,9 @@ linfa = { version = "0.7.1", path = "../.." }
3333
linfa-kernel = { version = "0.7.1", path = "../linfa-kernel" }
3434

3535
[dev-dependencies]
36-
linfa-datasets = { version = "0.7.1", path = "../../datasets", features = ["winequality", "diabetes"] }
36+
linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [
37+
"winequality",
38+
"diabetes",
39+
] }
3740
rand_xoshiro = "0.6"
3841
approx = "0.4"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use linfa::prelude::*;
2+
use linfa_svm::{error::Result, Svm};
3+
use ndarray::Array1;
4+
use ndarray_rand::{
5+
rand::{Rng, SeedableRng},
6+
rand_distr::Uniform,
7+
};
8+
use rand_xoshiro::Xoshiro256Plus;
9+
10+
/// Example inspired by https://scikit-learn.org/stable/auto_examples/svm/plot_svm_regression.html
11+
fn main() -> Result<()> {
12+
let mut rng = Xoshiro256Plus::seed_from_u64(42);
13+
let range = Uniform::new(0., 5.);
14+
let mut x: Vec<f64> = (0..40).map(|_| rng.sample(range)).collect();
15+
x.sort_by(|a, b| a.partial_cmp(b).unwrap());
16+
let x = Array1::from_vec(x);
17+
18+
let mut y = x.mapv(|v| v.sin());
19+
20+
// add some noise
21+
y.iter_mut()
22+
.enumerate()
23+
.filter(|(i, _)| i % 5 == 0)
24+
.for_each(|(_, y)| *y = 3. * (0.5 - rng.gen::<f64>()));
25+
26+
let x = x.into_shape((40, 1)).unwrap();
27+
let dataset = DatasetBase::new(x, y);
28+
let model = Svm::params()
29+
.c_svr(100., Some(0.1))
30+
.gaussian_kernel(10.)
31+
.fit(&dataset)?;
32+
33+
println!("{}", model);
34+
35+
let predicted = model.predict(&dataset);
36+
let err = predicted.mean_squared_error(&dataset).unwrap();
37+
println!("err={}", err);
38+
39+
Ok(())
40+
}

algorithms/linfa-svm/src/hyperparams.rs

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ impl<F: Float, T> SvmParams<F, T> {
134134
}
135135

136136
/// Sets the model to use the Polynomial kernel. For this kernel the
137-
/// distance between two points is computed as: `d(x, x') = (<x, x'> + costant)^(degree)`
137+
/// distance between two points is computed as: `d(x, x') = (<x, x'> + constant)^(degree)`
138138
pub fn polynomial_kernel(mut self, constant: F, degree: F) -> Self {
139139
self.0.kernel = Kernel::params().method(KernelMethod::Polynomial(constant, degree));
140140
self
@@ -168,16 +168,36 @@ impl<F: Float, T> SvmParams<F, T> {
168168
}
169169

170170
impl<F: Float> SvmParams<F, F> {
171-
/// Set the C value for regression
171+
/// Set the C value for regression and solver epsilon stopping condition.
172+
/// Loss epsilon value is fixed at 0.1.
173+
#[deprecated(since = "0.7.2", note = "Use .c_svr() and .eps()")]
172174
pub fn c_eps(mut self, c: F, eps: F) -> Self {
173-
self.0.c = Some((c, eps));
175+
self.0.c = Some((c, F::cast(0.1)));
174176
self.0.nu = None;
177+
self.0.solver_params.eps = eps;
175178
self
176179
}
177180

178-
/// Set the Nu-Eps value for regression
181+
/// Set the Nu value for regression and solver epsilon stopping condition.
182+
/// C value used is fixed at 1.0.
183+
#[deprecated(since = "0.7.2", note = "Use .nu_svr() and .eps()")]
179184
pub fn nu_eps(mut self, nu: F, eps: F) -> Self {
180-
self.0.nu = Some((nu, eps));
185+
self.0.nu = Some((nu, F::one()));
186+
self.0.c = None;
187+
self.0.solver_params.eps = eps;
188+
self
189+
}
190+
191+
/// Set the C value and optionally an epsilon value used in loss function (default 0.1) for regression
192+
pub fn c_svr(mut self, c: F, loss_eps: Option<F>) -> Self {
193+
self.0.c = Some((c, loss_eps.unwrap_or(F::cast(0.1))));
194+
self.0.nu = None;
195+
self
196+
}
197+
198+
/// Set the Nu and optionally a C value (default 1.) for regression
199+
pub fn nu_svr(mut self, nu: F, c: Option<F>) -> Self {
200+
self.0.nu = Some((nu, c.unwrap_or(F::one())));
181201
self.0.c = None;
182202
self
183203
}
@@ -219,7 +239,7 @@ impl<F: Float, L> ParamGuard for SvmParams<F, L> {
219239
}
220240
}
221241
if let Some((nu, _)) = self.0.nu {
222-
if nu <= F::zero() {
242+
if nu <= F::zero() || nu > F::one() {
223243
return Err(SvmError::InvalidNu(nu.to_f32().unwrap()));
224244
}
225245
}

algorithms/linfa-svm/src/regression.rs

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ pub fn fit_nu<F: Float>(
7777
dataset: ArrayView2<F>,
7878
kernel: Kernel<F>,
7979
target: &[F],
80-
c: F,
8180
nu: F,
81+
c: F,
8282
) -> Svm<F, F> {
8383
let mut alpha = vec![F::zero(); 2 * target.len()];
8484
let mut linear_term = vec![F::zero(); 2 * target.len()];
@@ -128,21 +128,21 @@ macro_rules! impl_regression {
128128
let target = target.as_slice().unwrap();
129129

130130
let ret = match (self.c(), self.nu()) {
131-
(Some((c, eps)), _) => fit_epsilon(
131+
(Some((c, p)), _) => fit_epsilon(
132132
self.solver_params().clone(),
133133
dataset.records().view(),
134134
kernel,
135135
target,
136136
c,
137-
eps,
137+
p,
138138
),
139-
(None, Some((nu, eps))) => fit_nu(
139+
(None, Some((nu, c))) => fit_nu(
140140
self.solver_params().clone(),
141141
dataset.records().view(),
142142
kernel,
143143
target,
144144
nu,
145-
eps,
145+
c,
146146
),
147147
_ => panic!("Set either C value or Nu value"),
148148
};
@@ -206,73 +206,95 @@ pub mod tests {
206206
use linfa::dataset::Dataset;
207207
use linfa::metrics::SingleTargetRegression;
208208
use linfa::traits::{Fit, Predict};
209-
use ndarray::Array;
210-
211-
#[test]
212-
fn test_linear_epsilon_regression() -> Result<()> {
213-
let target = Array::linspace(0f64, 10., 100);
214-
let mut sin_curve = Array::zeros((100, 1));
215-
for (i, val) in target.iter().enumerate() {
216-
sin_curve[(i, 0)] = *val;
217-
}
218-
219-
let dataset = Dataset::new(sin_curve, target);
220-
221-
let model = Svm::params()
222-
.nu_eps(2., 0.01)
223-
.gaussian_kernel(50.)
224-
.fit(&dataset)?;
209+
use linfa::DatasetBase;
210+
use ndarray::{Array, Array1, Array2};
225211

212+
fn _check_model(model: Svm<f64, f64>, dataset: &DatasetBase<Array2<f64>, Array1<f64>>) {
226213
println!("{}", model);
227-
228214
let predicted = model.predict(dataset.records());
215+
let err = predicted.mean_squared_error(&dataset).unwrap();
216+
println!("err={}", err);
229217
assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2);
230-
231-
Ok(())
232218
}
233219

234220
#[test]
235-
fn test_linear_nu_regression() -> Result<()> {
236-
let target = Array::linspace(0f64, 10., 100);
237-
let mut sin_curve = Array::zeros((100, 1));
238-
for (i, val) in target.iter().enumerate() {
239-
sin_curve[(i, 0)] = *val;
240-
}
241-
242-
let dataset = Dataset::new(sin_curve, target);
221+
fn test_epsilon_regression_linear() -> Result<()> {
222+
// simple 2d straight line
223+
let targets = Array::linspace(0f64, 10., 100);
224+
let records = targets.clone().into_shape((100, 1)).unwrap();
225+
let dataset = Dataset::new(records, targets);
243226

244227
let model = Svm::params()
245-
.nu_eps(2., 0.01)
246-
.gaussian_kernel(50.)
228+
.c_svr(5., None)
229+
.linear_kernel()
247230
.fit(&dataset)?;
231+
_check_model(model, &dataset);
248232

249-
println!("{}", model);
250-
251-
let predicted = model.predict(&dataset);
252-
assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2);
233+
// Old API
234+
#[allow(deprecated)]
235+
let model2 = Svm::params()
236+
.c_eps(5., 1e-3)
237+
.linear_kernel()
238+
.fit(&dataset)?;
239+
_check_model(model2, &dataset);
253240

254241
Ok(())
255242
}
256243

257244
#[test]
258-
fn test_regression_linear_kernel() -> Result<()> {
245+
fn test_nu_regression_linear() -> Result<()> {
259246
// simple 2d straight line
260247
let targets = Array::linspace(0f64, 10., 100);
261248
let records = targets.clone().into_shape((100, 1)).unwrap();
262-
263249
let dataset = Dataset::new(records, targets);
264250

265251
// Test the precomputed dot product in the linear kernel case
266252
let model = Svm::params()
267-
.nu_eps(2., 0.01)
253+
.nu_svr(0.5, Some(1.))
268254
.linear_kernel()
269255
.fit(&dataset)?;
256+
_check_model(model, &dataset);
270257

271-
println!("{}", model);
258+
// Old API
259+
#[allow(deprecated)]
260+
let model2 = Svm::params()
261+
.nu_eps(0.5, 1e-3)
262+
.linear_kernel()
263+
.fit(&dataset)?;
264+
_check_model(model2, &dataset);
265+
Ok(())
266+
}
272267

273-
let predicted = model.predict(&dataset);
274-
assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2);
268+
#[test]
269+
fn test_epsilon_regression_gaussian() -> Result<()> {
270+
let records = Array::linspace(0f64, 10., 100)
271+
.into_shape((100, 1))
272+
.unwrap();
273+
let sin_curve = records.mapv(|v| v.sin()).into_shape((100,)).unwrap();
274+
let dataset = Dataset::new(records, sin_curve);
275+
276+
let model = Svm::params()
277+
.c_svr(100., Some(0.1))
278+
.gaussian_kernel(10.)
279+
.eps(1e-3)
280+
.fit(&dataset)?;
281+
_check_model(model, &dataset);
282+
Ok(())
283+
}
284+
285+
#[test]
286+
fn test_nu_regression_polynomial() -> Result<()> {
287+
let n = 100;
288+
let records = Array::linspace(0f64, 5., n).into_shape((n, 1)).unwrap();
289+
let sin_curve = records.mapv(|v| v.sin()).into_shape((n,)).unwrap();
290+
let dataset = Dataset::new(records, sin_curve);
275291

292+
let model = Svm::params()
293+
.nu_svr(0.01, None)
294+
.polynomial_kernel(1., 3.)
295+
.eps(1e-3)
296+
.fit(&dataset)?;
297+
_check_model(model, &dataset);
276298
Ok(())
277299
}
278300
}

0 commit comments

Comments
 (0)