-
-
Notifications
You must be signed in to change notification settings - Fork 260
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[linfa-svm] Fix SVR nu parameter passing and rework SVR parameterizat…
…ion API (#370) * New API for SVR parameterization * Fix nu parameter passing * Add SVR example * Upload code coverage on pull request only * Bump linfa-svm version to 0.7.2 * Try to compute code coverage only on PR on master * Add SVR test with polynomial kernel * Fix deprecated functions (as it should have been wired) * Test rewired deprecated API
- Loading branch information
Showing
5 changed files
with
140 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[package] | ||
name = "linfa-svm" | ||
version = "0.7.1" | ||
version = "0.7.2" | ||
edition = "2018" | ||
authors = ["Lorenz Schmidt <[email protected]>"] | ||
description = "Support Vector Machines" | ||
|
@@ -33,6 +33,9 @@ linfa = { version = "0.7.1", path = "../.." } | |
linfa-kernel = { version = "0.7.1", path = "../linfa-kernel" } | ||
|
||
[dev-dependencies] | ||
linfa-datasets = { version = "0.7.1", path = "../../datasets", features = ["winequality", "diabetes"] } | ||
linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ | ||
"winequality", | ||
"diabetes", | ||
] } | ||
rand_xoshiro = "0.6" | ||
approx = "0.4" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
use linfa::prelude::*; | ||
use linfa_svm::{error::Result, Svm}; | ||
use ndarray::Array1; | ||
use ndarray_rand::{ | ||
rand::{Rng, SeedableRng}, | ||
rand_distr::Uniform, | ||
}; | ||
use rand_xoshiro::Xoshiro256Plus; | ||
|
||
/// Example inspired by https://scikit-learn.org/stable/auto_examples/svm/plot_svm_regression.html | ||
fn main() -> Result<()> { | ||
let mut rng = Xoshiro256Plus::seed_from_u64(42); | ||
let range = Uniform::new(0., 5.); | ||
let mut x: Vec<f64> = (0..40).map(|_| rng.sample(range)).collect(); | ||
x.sort_by(|a, b| a.partial_cmp(b).unwrap()); | ||
let x = Array1::from_vec(x); | ||
|
||
let mut y = x.mapv(|v| v.sin()); | ||
|
||
// add some noise | ||
y.iter_mut() | ||
.enumerate() | ||
.filter(|(i, _)| i % 5 == 0) | ||
.for_each(|(_, y)| *y = 3. * (0.5 - rng.gen::<f64>())); | ||
|
||
let x = x.into_shape((40, 1)).unwrap(); | ||
let dataset = DatasetBase::new(x, y); | ||
let model = Svm::params() | ||
.c_svr(100., Some(0.1)) | ||
.gaussian_kernel(10.) | ||
.fit(&dataset)?; | ||
|
||
println!("{}", model); | ||
|
||
let predicted = model.predict(&dataset); | ||
let err = predicted.mean_squared_error(&dataset).unwrap(); | ||
println!("err={}", err); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters