diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8ad88a8..cfda8cf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,7 +17,7 @@ jobs: steps: - uses: actions/checkout@v2 - name: Install egui dependencies - run: sudo apt-get install libxcb-render0-dev libxcb-shape0-dev libxcb-xfixes0-dev libspeechd-dev libxkbcommon-dev libssl-dev + run: sudo apt-get install -y libclang-dev libgtk-3-dev libxcb-render0-dev libxcb-shape0-dev libxcb-xfixes0-dev libxkbcommon-dev libssl-dev - name: Build run: cargo build --release --verbose --all-features - name: Run tests diff --git a/.gitignore b/.gitignore index b7b4f51..eb53182 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ Cargo.lock .idea/ /examples/*.aml /examples/*.yaml -/examples/*.sc \ No newline at end of file +/examples/*.sc +.vscode diff --git a/Cargo.toml b/Cargo.toml index 8165e32..8b43581 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "automl" -version = "0.2.7" +version = "0.3.0" authors = ["Chris McComb "] description = "Automated machine learning for classification and regression" edition = "2021" diff --git a/README.md b/README.md index 9379cdd..b2a7eb4 100644 --- a/README.md +++ b/README.md @@ -3,14 +3,16 @@ [![docs.rs](https://img.shields.io/docsrs/automl/latest?logo=rust)](https://docs.rs/automl) # AutoML with SmartCore + AutoML is _Automated Machine Learning_, referring to processes and methods to make machine learning more accessible for a general audience. This crate builds on top of the [smartcore](https://docs.rs/smartcore/) machine learning framework, and provides some utilities to quickly train and compare models. # Install + To use the latest released version of `AutoML`, add this to your `Cargo.toml`: ```toml -automl = "0.2.7" +automl = "0.3.0" ``` To use the bleeding edge instead, add this: ```toml @@ -18,14 +20,18 @@ automl = { git = "https://github.com/cmccomb/rust-automl" } ``` # Usage + Running the following: + ```rust let dataset = smartcore::dataset::breast_cancer::load_dataset(); let settings = automl::Settings::default_classification(); let mut classifier = automl::SupervisedModel::new(dataset, settings); classifier.train(); ``` + will perform a comparison of classifier models using cross-validation. Printing the classifier object will yield: + ```text ┌────────────────────────────────┬─────────────────────┬───────────────────┬──────────────────┐ │ Model │ Time │ Training Accuracy │ Testing Accuracy │ @@ -45,17 +51,20 @@ will perform a comparison of classifier models using cross-validation. Printing │ Support Vector Classifier │ 4s 187ms 61us 708ns │ 0.57 │ 0.57 │ └────────────────────────────────┴─────────────────────┴───────────────────┴──────────────────┘ ``` + You can then perform inference using the best model with the `predict` method. ## Features -This crate has several features that add some additional methods -| Feature | Description | -|:----------|:----------------------------------------------------------------------------------------------------------| -| `nd` | Adds methods for predicting/reading data using [`ndarray`](https://crates.io/crates/ndarray). | -| `csv` | Adds methods for predicting/reading data from a .csv using [`polars`](https://crates.io/crates/polars). | +This crate has several features that add some additional methods. 
+ +| Feature | Description | +| :------ | :------------------------------------------------------------------------------------------------------ | +| `nd` | Adds methods for predicting/reading data using [`ndarray`](https://crates.io/crates/ndarray). | +| `csv` | Adds methods for predicting/reading data from a .csv using [`polars`](https://crates.io/crates/polars). | ## Capabilities + - Feature Engineering - PCA - SVD diff --git a/examples/classification_save_best.rs b/examples/classification_save_best.rs index 7b364ff..7d33258 100644 --- a/examples/classification_save_best.rs +++ b/examples/classification_save_best.rs @@ -17,7 +17,7 @@ fn main() { // Load that model for use directly in SmartCore let mut buf: Vec = Vec::new(); - std::fs::File::open(&file_name) + std::fs::File::open(file_name) .and_then(|mut f| f.read_to_end(&mut buf)) .expect("Cannot load model from file."); let model: LogisticRegression> = diff --git a/src/algorithms/categorical_naive_bayes_classifier.rs b/src/algorithms/categorical_naive_bayes_classifier.rs index 8494f06..cad8fef 100644 --- a/src/algorithms/categorical_naive_bayes_classifier.rs +++ b/src/algorithms/categorical_naive_bayes_classifier.rs @@ -1,11 +1,18 @@ -use smartcore::linalg::naive::dense_matrix::DenseMatrix; -use smartcore::model_selection::cross_validate; -use smartcore::naive_bayes::categorical::CategoricalNB; +//! Categorical Naive Bayes Classifier. + +use smartcore::{ + linalg::naive::dense_matrix::DenseMatrix, + model_selection::{cross_validate, CrossValidationResult}, + naive_bayes::categorical::CategoricalNB, +}; use crate::{Algorithm, Settings}; -use smartcore::model_selection::CrossValidationResult; -pub(crate) struct CategoricalNaiveBayesClassifierWrapper {} +/// The Categorical Naive Bayes Classifier. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/naive_bayes.html#categorical-naive-bayes) +/// for a more in-depth description of the algorithm. +pub struct CategoricalNaiveBayesClassifierWrapper {} impl super::ModelWrapper for CategoricalNaiveBayesClassifierWrapper { fn cv( @@ -41,7 +48,7 @@ impl super::ModelWrapper for CategoricalNaiveBayesClassifierWrapper { fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { let model: CategoricalNB> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/decision_tree_classifier.rs b/src/algorithms/decision_tree_classifier.rs index 6496cf7..fa90225 100644 --- a/src/algorithms/decision_tree_classifier.rs +++ b/src/algorithms/decision_tree_classifier.rs @@ -1,3 +1,5 @@ +//! Decision Tree Classifier. + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, model_selection::{cross_validate, CrossValidationResult}, @@ -6,7 +8,11 @@ use smartcore::{ use crate::{Algorithm, Settings}; -pub(crate) struct DecisionTreeClassifierWrapper {} +/// The Decision Tree Classifier. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/tree.html#classification) +/// for a more in-depth description of the algorithm. 
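The updated `classification_save_best` example above also drops the redundant `&` borrow when opening `file_name`. For context, here is a minimal end-to-end sketch of that workflow (train, save the best model, hand the raw bytes back to SmartCore). The file name is illustrative, `bincode` and `smartcore` are assumed to be direct dependencies, and the final deserialization assumes the winning model is a logistic regression, as in the example itself:

```rust
use std::io::Read;

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::logistic_regression::LogisticRegression;

fn main() {
    // Train a comparison of classifiers on the breast cancer dataset.
    let dataset = smartcore::dataset::breast_cancer::load_dataset();
    let settings = automl::Settings::default_classification();
    let mut classifier = automl::SupervisedModel::new(dataset, settings);
    classifier.train();

    // Persist only the best model, then read its bytes back.
    let file_name = "best_model.sc"; // illustrative name
    classifier.save_best(file_name);

    let mut buf: Vec<u8> = Vec::new();
    std::fs::File::open(file_name)
        .and_then(|mut f| f.read_to_end(&mut buf))
        .expect("Cannot load model from file.");

    // Assumption: the best-performing model happens to be a logistic regression.
    let model: LogisticRegression<f32, DenseMatrix<f32>> =
        bincode::deserialize(&buf).expect("Cannot deserialize model.");
    let _ = model; // ready for direct use with SmartCore
}
```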
+pub struct DecisionTreeClassifierWrapper {} impl super::ModelWrapper for DecisionTreeClassifierWrapper { fn cv( @@ -49,7 +55,7 @@ impl super::ModelWrapper for DecisionTreeClassifierWrapper { } fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { - let model: DecisionTreeClassifier = bincode::deserialize(&*final_model).unwrap(); + let model: DecisionTreeClassifier = bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/decision_tree_regressor.rs b/src/algorithms/decision_tree_regressor.rs index 58fe63d..4bb6dc2 100644 --- a/src/algorithms/decision_tree_regressor.rs +++ b/src/algorithms/decision_tree_regressor.rs @@ -1,3 +1,5 @@ +//! Decision Tree Regressor. + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, model_selection::{cross_validate, CrossValidationResult}, @@ -6,7 +8,11 @@ use smartcore::{ use crate::{Algorithm, Settings}; -pub(crate) struct DecisionTreeRegressorWrapper {} +/// The Decision Tree Regressor. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/tree.html#regression) +/// for a more in-depth description of the algorithm. +pub struct DecisionTreeRegressorWrapper {} impl super::ModelWrapper for DecisionTreeRegressorWrapper { fn cv( @@ -49,7 +55,7 @@ impl super::ModelWrapper for DecisionTreeRegressorWrapper { } fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { - let model: DecisionTreeRegressor = bincode::deserialize(&*final_model).unwrap(); + let model: DecisionTreeRegressor = bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/elastic_net_regressor.rs b/src/algorithms/elastic_net_regressor.rs index 528d6ef..86eff40 100644 --- a/src/algorithms/elastic_net_regressor.rs +++ b/src/algorithms/elastic_net_regressor.rs @@ -1,3 +1,5 @@ +//! Elastic Net Regressor. + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, linear::elastic_net::ElasticNet, model_selection::cross_validate, model_selection::CrossValidationResult, @@ -5,7 +7,11 @@ use smartcore::{ use crate::{Algorithm, Settings}; -pub(crate) struct ElasticNetRegressorWrapper {} +/// The Elastic Net Regressor. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#elastic-net) +/// for a more in-depth description of the algorithm. +pub struct ElasticNetRegressorWrapper {} impl super::ModelWrapper for ElasticNetRegressorWrapper { fn cv( @@ -40,7 +46,7 @@ impl super::ModelWrapper for ElasticNetRegressorWrapper { } fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { - let model: ElasticNet> = bincode::deserialize(&*final_model).unwrap(); + let model: ElasticNet> = bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/gaussian_naive_bayes_classifier.rs b/src/algorithms/gaussian_naive_bayes_classifier.rs index 3b3d91a..a638278 100644 --- a/src/algorithms/gaussian_naive_bayes_classifier.rs +++ b/src/algorithms/gaussian_naive_bayes_classifier.rs @@ -1,11 +1,18 @@ +//! 
Gaussian Naive Bayes Classifier + use smartcore::{ - linalg::naive::dense_matrix::DenseMatrix, model_selection::cross_validate, - model_selection::CrossValidationResult, naive_bayes::gaussian::GaussianNB, + linalg::naive::dense_matrix::DenseMatrix, + model_selection::{cross_validate, CrossValidationResult}, + naive_bayes::gaussian::GaussianNB, }; use crate::{Algorithm, Settings}; -pub(crate) struct GaussianNaiveBayesClassifierWrapper {} +/// The Gaussian Naive Bayes Classifier. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/naive_bayes.html#gaussian-naive-bayes) +/// for a more in-depth description of the algorithm. +pub struct GaussianNaiveBayesClassifierWrapper {} impl super::ModelWrapper for GaussianNaiveBayesClassifierWrapper { fn cv( @@ -40,7 +47,7 @@ impl super::ModelWrapper for GaussianNaiveBayesClassifierWrapper { } fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { - let model: GaussianNB> = bincode::deserialize(&*final_model).unwrap(); + let model: GaussianNB> = bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/knn_classifier.rs b/src/algorithms/knn_classifier.rs index 13ee395..f8f48c5 100644 --- a/src/algorithms/knn_classifier.rs +++ b/src/algorithms/knn_classifier.rs @@ -1,19 +1,24 @@ +//! KNN Classifier + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, math::distance::{ euclidian::Euclidian, hamming::Hamming, mahalanobis::Mahalanobis, manhattan::Manhattan, minkowski::Minkowski, Distances, }, - model_selection::cross_validate, + model_selection::{cross_validate, CrossValidationResult}, neighbors::knn_classifier::{ KNNClassifier, KNNClassifierParameters as SmartcoreKNNClassifierParameters, }, }; use crate::{Algorithm, Distance, Settings}; -use smartcore::model_selection::CrossValidationResult; -pub(crate) struct KNNClassifierWrapper {} +/// The KNN Classifier. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/neighbors.html#classification) +/// for a more in-depth description of the algorithm. +pub struct KNNClassifierWrapper {} impl super::ModelWrapper for KNNClassifierWrapper { fn cv( @@ -281,27 +286,26 @@ impl super::ModelWrapper for KNNClassifierWrapper { match settings.knn_classifier_settings.as_ref().unwrap().distance { Distance::Euclidean => { let model: KNNClassifier = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Distance::Manhattan => { let model: KNNClassifier = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Distance::Minkowski(_) => { let model: KNNClassifier = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Distance::Mahalanobis => { let model: KNNClassifier>> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Distance::Hamming => { - let model: KNNClassifier = - bincode::deserialize(&*final_model).unwrap(); + let model: KNNClassifier = bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/knn_regressor.rs b/src/algorithms/knn_regressor.rs index 3d91594..7672db8 100644 --- a/src/algorithms/knn_regressor.rs +++ b/src/algorithms/knn_regressor.rs @@ -1,3 +1,5 @@ +//! 
KNN Regressor + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, math::distance::{ @@ -13,7 +15,11 @@ use smartcore::{ use crate::{Algorithm, Distance, Settings}; -pub(crate) struct KNNRegressorWrapper {} +/// The KNN Regressor. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/neighbors.html#regression) +/// for a more in-depth description of the algorithm. +pub struct KNNRegressorWrapper {} impl super::ModelWrapper for KNNRegressorWrapper { fn cv( @@ -288,27 +294,26 @@ impl super::ModelWrapper for KNNRegressorWrapper { match settings.knn_regressor_settings.as_ref().unwrap().distance { Distance::Euclidean => { let model: KNNRegressor = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Distance::Manhattan => { let model: KNNRegressor = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Distance::Minkowski(_) => { let model: KNNRegressor = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Distance::Mahalanobis => { let model: KNNRegressor>> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Distance::Hamming => { - let model: KNNRegressor = - bincode::deserialize(&*final_model).unwrap(); + let model: KNNRegressor = bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/lasso_regressor.rs b/src/algorithms/lasso_regressor.rs index 19096cc..9e546c8 100644 --- a/src/algorithms/lasso_regressor.rs +++ b/src/algorithms/lasso_regressor.rs @@ -1,3 +1,5 @@ +//! LASSO regression algorithm. + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, linear::lasso::Lasso, model_selection::cross_validate, model_selection::CrossValidationResult, @@ -5,7 +7,11 @@ use smartcore::{ use crate::{Algorithm, Settings}; -pub(crate) struct LassoRegressorWrapper {} +/// The LASSO regression algorithm. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#lasso) +/// for a more in-depth description of the algorithm. +pub struct LassoRegressorWrapper {} impl super::ModelWrapper for LassoRegressorWrapper { fn cv( @@ -49,7 +55,7 @@ impl super::ModelWrapper for LassoRegressorWrapper { fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { let model: Lasso> = - bincode::deserialize(&*final_model).expect("Cannot deserialize trained model."); + bincode::deserialize(final_model).expect("Cannot deserialize trained model."); model.predict(x).expect("Error during inference.") } } diff --git a/src/algorithms/linear_regressor.rs b/src/algorithms/linear_regressor.rs index 16cabf6..97a42ad 100644 --- a/src/algorithms/linear_regressor.rs +++ b/src/algorithms/linear_regressor.rs @@ -1,11 +1,18 @@ +//! Linear regression algorithm. + use smartcore::{ - linalg::naive::dense_matrix::DenseMatrix, linear::linear_regression::LinearRegression, - model_selection::cross_validate, model_selection::CrossValidationResult, + linalg::naive::dense_matrix::DenseMatrix, + linear::linear_regression::LinearRegression, + model_selection::{cross_validate, CrossValidationResult}, }; use crate::{Algorithm, Settings}; -pub(crate) struct LinearRegressorWrapper {} +/// The Linear regression algorithm. 
+/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#ordinary-least-squares) +/// for a more in-depth description of the algorithm. +pub struct LinearRegressorWrapper {} impl super::ModelWrapper for LinearRegressorWrapper { fn cv( @@ -49,7 +56,7 @@ impl super::ModelWrapper for LinearRegressorWrapper { fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { let model: LinearRegression> = - bincode::deserialize(&*final_model).expect("Cannot deserialize trained model."); + bincode::deserialize(final_model).expect("Cannot deserialize trained model."); model.predict(x).expect("Error during inference.") } } diff --git a/src/algorithms/logistic_regression.rs b/src/algorithms/logistic_regression.rs index 82806fe..22b65b6 100644 --- a/src/algorithms/logistic_regression.rs +++ b/src/algorithms/logistic_regression.rs @@ -1,10 +1,16 @@ +//! Logistic Regression + use crate::{Algorithm, Settings}; use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, linear::logistic_regression::LogisticRegression, model_selection::cross_validate, model_selection::CrossValidationResult, }; -pub(crate) struct LogisticRegressionWrapper {} +/// The Logistic Regression algorithm. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression) +/// for a more in-depth description of the algorithm. +pub struct LogisticRegressionWrapper {} impl super::ModelWrapper for LogisticRegressionWrapper { fn cv( @@ -36,7 +42,7 @@ impl super::ModelWrapper for LogisticRegressionWrapper { fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { let model: LogisticRegression> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index c1339b8..1eb73de 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -1,47 +1,73 @@ +//! # Algorithms +//! +//! This module contains the wrappers for the algorithms provided by this crate. +//! The algorithms are all available through the common interface of the `ModelWrapper` trait. +//! +//! The available algorithms include: +//! +//! * Classification algorithms: +//! - Logistic Regression +//! - Random Forest Classifier +//! - K-Nearest Neighbors Classifier +//! - Decision Tree Classifier +//! - Gaussian Naive Bayes Classifier +//! - Categorical Naive Bayes Classifier +//! - Support Vector Classifier +//! +//! * Regression algorithms: +//! - Linear Regression +//! - Elastic Net Regressor +//! - Lasso Regressor +//! - K-Nearest Neighbors Regressor +//! - Ridge Regressor +//! - Random Forest Regressor +//! - Decision Tree Regressor +//! 
- Support Vector Regressor + mod linear_regressor; -pub(crate) use linear_regressor::LinearRegressorWrapper; +pub use linear_regressor::LinearRegressorWrapper; mod elastic_net_regressor; -pub(crate) use elastic_net_regressor::ElasticNetRegressorWrapper; +pub use elastic_net_regressor::ElasticNetRegressorWrapper; mod lasso_regressor; -pub(crate) use lasso_regressor::LassoRegressorWrapper; +pub use lasso_regressor::LassoRegressorWrapper; mod knn_regressor; -pub(crate) use knn_regressor::KNNRegressorWrapper; +pub use knn_regressor::KNNRegressorWrapper; mod ridge_regressor; -pub(crate) use ridge_regressor::RidgeRegressorWrapper; +pub use ridge_regressor::RidgeRegressorWrapper; mod logistic_regression; -pub(crate) use logistic_regression::LogisticRegressionWrapper; +pub use logistic_regression::LogisticRegressionWrapper; mod random_forest_classifier; -pub(crate) use random_forest_classifier::RandomForestClassifierWrapper; +pub use random_forest_classifier::RandomForestClassifierWrapper; mod random_forest_regressor; -pub(crate) use random_forest_regressor::RandomForestRegressorWrapper; +pub use random_forest_regressor::RandomForestRegressorWrapper; mod knn_classifier; -pub(crate) use knn_classifier::KNNClassifierWrapper; +pub use knn_classifier::KNNClassifierWrapper; mod decision_tree_classifier; -pub(crate) use decision_tree_classifier::DecisionTreeClassifierWrapper; +pub use decision_tree_classifier::DecisionTreeClassifierWrapper; mod decision_tree_regressor; -pub(crate) use decision_tree_regressor::DecisionTreeRegressorWrapper; +pub use decision_tree_regressor::DecisionTreeRegressorWrapper; mod gaussian_naive_bayes_classifier; -pub(crate) use gaussian_naive_bayes_classifier::GaussianNaiveBayesClassifierWrapper; +pub use gaussian_naive_bayes_classifier::GaussianNaiveBayesClassifierWrapper; mod categorical_naive_bayes_classifier; -pub(crate) use categorical_naive_bayes_classifier::CategoricalNaiveBayesClassifierWrapper; +pub use categorical_naive_bayes_classifier::CategoricalNaiveBayesClassifierWrapper; mod support_vector_classifier; -pub(crate) use support_vector_classifier::SupportVectorClassifierWrapper; +pub use support_vector_classifier::SupportVectorClassifierWrapper; mod support_vector_regressor; -pub(crate) use support_vector_regressor::SupportVectorRegressorWrapper; +pub use support_vector_regressor::SupportVectorRegressorWrapper; use crate::{Algorithm, Settings}; use smartcore::linalg::naive::dense_matrix::DenseMatrix; @@ -50,7 +76,22 @@ use smartcore::model_selection::CrossValidationResult; use crate::settings::FinalModel; use std::time::{Duration, Instant}; +/// Trait for wrapping models pub trait ModelWrapper { + /// Perform cross-validation and return the results + /// + /// # Arguments + /// + /// * `x` - The input data + /// * `y` - The output data + /// * `settings` - The settings for the model + /// + /// # Returns + /// + /// * `CrossValidationResult` - The cross-validation results + /// * `Algorithm` - The algorithm used + /// * `Duration` - The time taken to perform the cross-validation + /// * `Vec` - The final model fn cv_model( x: &DenseMatrix, y: &Vec, @@ -70,16 +111,19 @@ pub trait ModelWrapper { ) } - // Perform cross-validation + /// Perform cross-validation + #[allow(clippy::ptr_arg)] fn cv( x: &DenseMatrix, y: &Vec, settings: &Settings, ) -> (CrossValidationResult, Algorithm); - // Train a model + /// Train a model + #[allow(clippy::ptr_arg)] fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec; - // Perform a prediction + /// Perform a prediction + 
#[allow(clippy::ptr_arg)] fn predict(x: &DenseMatrix, final_model: &Vec, settings: &Settings) -> Vec; } diff --git a/src/algorithms/random_forest_classifier.rs b/src/algorithms/random_forest_classifier.rs index 76b8846..09eb111 100644 --- a/src/algorithms/random_forest_classifier.rs +++ b/src/algorithms/random_forest_classifier.rs @@ -1,12 +1,18 @@ +//! Random Forest Classifier + use smartcore::{ ensemble::random_forest_classifier::RandomForestClassifier, - linalg::naive::dense_matrix::DenseMatrix, model_selection::cross_validate, - model_selection::CrossValidationResult, + linalg::naive::dense_matrix::DenseMatrix, + model_selection::{cross_validate, CrossValidationResult}, }; use crate::{Algorithm, Settings}; -pub(crate) struct RandomForestClassifierWrapper {} +/// The Random Forest Classifier. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/ensemble.html#random-forests) +/// for a more in-depth description of the algorithm. +pub struct RandomForestClassifierWrapper {} impl super::ModelWrapper for RandomForestClassifierWrapper { fn cv( @@ -49,7 +55,7 @@ impl super::ModelWrapper for RandomForestClassifierWrapper { } fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { - let model: RandomForestClassifier = bincode::deserialize(&*final_model).unwrap(); + let model: RandomForestClassifier = bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/random_forest_regressor.rs b/src/algorithms/random_forest_regressor.rs index 81bdee1..54642ba 100644 --- a/src/algorithms/random_forest_regressor.rs +++ b/src/algorithms/random_forest_regressor.rs @@ -1,12 +1,18 @@ +//! Random Forest Regressor + use smartcore::{ ensemble::random_forest_regressor::RandomForestRegressor, - linalg::naive::dense_matrix::DenseMatrix, model_selection::cross_validate, - model_selection::CrossValidationResult, + linalg::naive::dense_matrix::DenseMatrix, + model_selection::{cross_validate, CrossValidationResult}, }; use crate::{Algorithm, Settings}; -pub(crate) struct RandomForestRegressorWrapper {} +/// The Random Forest Regressor. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/ensemble.html#random-forests) +/// for a more in-depth description of the algorithm. +pub struct RandomForestRegressorWrapper {} impl super::ModelWrapper for RandomForestRegressorWrapper { fn cv( @@ -49,7 +55,7 @@ impl super::ModelWrapper for RandomForestRegressorWrapper { } fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { - let model: RandomForestRegressor = bincode::deserialize(&*final_model).unwrap(); + let model: RandomForestRegressor = bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/ridge_regressor.rs b/src/algorithms/ridge_regressor.rs index 9e48fa3..6ce0b25 100644 --- a/src/algorithms/ridge_regressor.rs +++ b/src/algorithms/ridge_regressor.rs @@ -1,3 +1,5 @@ +//! Ridge regression algorithm. + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, linear::ridge_regression::RidgeRegression, model_selection::cross_validate, model_selection::CrossValidationResult, @@ -5,7 +7,11 @@ use smartcore::{ use crate::{Algorithm, Settings}; -pub(crate) struct RidgeRegressorWrapper {} +/// The Ridge regression algorithm. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#ridge-regression) +/// for a more in-depth description of the algorithm. 
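The new `#[allow(clippy::ptr_arg)]` attributes on the trait methods silence Clippy's complaint that `&Vec<T>` parameters should be `&[T]`; the signatures are kept as-is so they line up with what smartcore's `cross_validate` expects for targets. A minimal sketch of what the lint flags:

```rust
// `&[T]` accepts more callers at no cost, which is why Clippy prefers it.
fn sum_slice(y: &[f32]) -> f32 {
    y.iter().sum()
}

// The trait keeps `&Vec<f32>` to match smartcore's signatures, so the lint
// is allowed rather than "fixed".
#[allow(clippy::ptr_arg)]
fn sum_vec(y: &Vec<f32>) -> f32 {
    y.iter().sum()
}

fn main() {
    let y = vec![1.0_f32, 2.0, 3.0];
    assert_eq!(sum_slice(&y), sum_vec(&y));
}
```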
+pub struct RidgeRegressorWrapper {} impl super::ModelWrapper for RidgeRegressorWrapper { fn cv( @@ -36,7 +42,7 @@ impl super::ModelWrapper for RidgeRegressorWrapper { fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { let model: RidgeRegression> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/support_vector_classifier.rs b/src/algorithms/support_vector_classifier.rs index cfcadc3..4d76f06 100644 --- a/src/algorithms/support_vector_classifier.rs +++ b/src/algorithms/support_vector_classifier.rs @@ -1,3 +1,5 @@ +//! Support Vector Classifier + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, model_selection::cross_validate, @@ -10,7 +12,11 @@ use smartcore::{ use crate::{Algorithm, Kernel, Settings}; -pub(crate) struct SupportVectorClassifierWrapper {} +/// The Support Vector Classifier. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/svm.html#svm-classification) +/// for a more in-depth description of the algorithm. +pub struct SupportVectorClassifierWrapper {} impl super::ModelWrapper for SupportVectorClassifierWrapper { fn cv( @@ -120,22 +126,22 @@ impl super::ModelWrapper for SupportVectorClassifierWrapper { match settings.svc_settings.as_ref().unwrap().kernel { Kernel::Linear => { let model: SVC, LinearKernel> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Kernel::Polynomial(_, _, _) => { let model: SVC, PolynomialKernel> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Kernel::RBF(_) => { let model: SVC, RBFKernel> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Kernel::Sigmoid(_, _) => { let model: SVC, SigmoidKernel> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } diff --git a/src/algorithms/support_vector_regressor.rs b/src/algorithms/support_vector_regressor.rs index f336b6a..a34a557 100644 --- a/src/algorithms/support_vector_regressor.rs +++ b/src/algorithms/support_vector_regressor.rs @@ -1,3 +1,5 @@ +//! Support Vector Regressor + use smartcore::{ linalg::naive::dense_matrix::DenseMatrix, model_selection::cross_validate, @@ -10,7 +12,11 @@ use smartcore::{ use crate::{Algorithm, Kernel, Settings}; -pub(crate) struct SupportVectorRegressorWrapper {} +/// The Support Vector Regressor. +/// +/// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/svm.html#svm-regression) +/// for a more in-depth description of the algorithm. 
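In the support-vector `predict` paths, the deserialization call is repeated once per kernel because `SVC` and `SVR` are generic over the kernel type, and `bincode` can only decode into a concrete, fully named type. The same constraint in miniature, using plain `serde` and `bincode` with an illustrative payload type:

```rust
use serde::de::DeserializeOwned;

/// Decode bytes into whatever concrete type the caller names.
fn decode<T: DeserializeOwned>(bytes: &[u8]) -> T {
    bincode::deserialize(bytes).expect("Cannot deserialize model.")
}

fn main() {
    let bytes = bincode::serialize(&vec![1.0_f32, 2.0]).unwrap();
    // The target type must be spelled out at the call site, which is exactly
    // why the kernel match names a concrete SVC/SVR type in every arm.
    let decoded: Vec<f32> = decode(&bytes);
    assert_eq!(decoded, vec![1.0, 2.0]);
}
```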
+pub struct SupportVectorRegressorWrapper {} impl super::ModelWrapper for SupportVectorRegressorWrapper { fn cv( @@ -120,84 +126,24 @@ impl super::ModelWrapper for SupportVectorRegressorWrapper { match settings.svr_settings.as_ref().unwrap().kernel { Kernel::Linear => { let model: SVR, LinearKernel> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Kernel::Polynomial(_, _, _) => { let model: SVR, PolynomialKernel> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Kernel::RBF(_) => { let model: SVR, RBFKernel> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } Kernel::Sigmoid(_, _) => { let model: SVR, SigmoidKernel> = - bincode::deserialize(&*final_model).unwrap(); + bincode::deserialize(final_model).unwrap(); model.predict(x).unwrap() } } } } - -// -// let start = Instant::now(); -// let cv = match self.settings.svr_settings.as_ref().unwrap().kernel { -// Kernel::Linear => cross_validate( -// SVR::fit, -// &self.x, -// &self.y, -// SmartcoreSVRParameters::default() -// .with_tol(self.settings.svr_settings.as_ref().unwrap().tol) -// .with_c(self.settings.svr_settings.as_ref().unwrap().c) -// .with_eps(self.settings.svr_settings.as_ref().unwrap().c) -// .with_kernel(Kernels::linear()), -// self.get_kfolds(), -// metric, -// ) -// .unwrap(), -// Kernel::Polynomial(degree, gamma, coef) => cross_validate( -// SVR::fit, -// &self.x, -// &self.y, -// SmartcoreSVRParameters::default() -// .with_tol(self.settings.svr_settings.as_ref().unwrap().tol) -// .with_c(self.settings.svr_settings.as_ref().unwrap().c) -// .with_eps(self.settings.svr_settings.as_ref().unwrap().c) -// .with_kernel(Kernels::polynomial(degree, gamma, coef)), -// self.get_kfolds(), -// metric, -// ) -// .unwrap(), -// Kernel::RBF(gamma) => cross_validate( -// SVR::fit, -// &self.x, -// &self.y, -// SmartcoreSVRParameters::default() -// .with_tol(self.settings.svr_settings.as_ref().unwrap().tol) -// .with_c(self.settings.svr_settings.as_ref().unwrap().c) -// .with_eps(self.settings.svr_settings.as_ref().unwrap().c) -// .with_kernel(Kernels::rbf(gamma)), -// self.get_kfolds(), -// metric, -// ) -// .unwrap(), -// Kernel::Sigmoid(gamma, coef) => cross_validate( -// SVR::fit, -// &self.x, -// &self.y, -// SmartcoreSVRParameters::default() -// .with_tol(self.settings.svr_settings.as_ref().unwrap().tol) -// .with_c(self.settings.svr_settings.as_ref().unwrap().c) -// .with_eps(self.settings.svr_settings.as_ref().unwrap().c) -// .with_kernel(Kernels::sigmoid(gamma, coef)), -// self.get_kfolds(), -// metric, -// ) -// .unwrap(), -// }; -// let end = Instant::now(); -// let d = end.duration_since(start); -// self.add_model(Algorithm::SVR, cv, d); diff --git a/src/lib.rs b/src/lib.rs index 00f3722..6974dfc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,16 @@ -#![warn(clippy::all)] -#![warn(missing_docs)] -#![warn(rustdoc::missing_doc_code_examples)] -#![warn(clippy::missing_docs_in_private_items)] +#![deny(clippy::correctness)] +#![warn( + clippy::all, + clippy::suspicious, + clippy::complexity, + clippy::perf, + clippy::style, + clippy::pedantic, + clippy::nursery, + clippy::missing_docs_in_private_items +)] +#![allow(clippy::module_name_repetitions, clippy::too_many_lines)] +#![warn(missing_docs, rustdoc::missing_doc_code_examples)] #![doc = include_str!("../README.md")] pub mod settings; @@ -158,18 +167,26 @@ impl 
IntoLabels for Array1 { /// Trains and compares supervised models #[derive(serde::Serialize, serde::Deserialize)] pub struct SupervisedModel { + /// Settings for the model. settings: Settings, + /// The training data. x_train: DenseMatrix, + /// The training labels. y_train: Vec, + /// The validation data. x_val: DenseMatrix, + /// The validation labels. y_val: Vec, + /// The number of classes in the data. number_of_classes: usize, + /// The results of the model comparison. comparison: Vec, + /// The final model. metamodel: Model, - preprocessing: ( - Option>>, - Option>>, - ), + /// PCA model for preprocessing. + preprocessing_pca: Option>>, + /// SVD model for preprocessing. + preprocessing_svd: Option>>, } impl SupervisedModel { @@ -227,7 +244,7 @@ impl SupervisedModel { D: IntoSupervisedData, { let (x, y) = data.to_supervised_data(); - SupervisedModel::build(x, y, settings) + Self::build(x, y, settings) } /// Load the supervised model from a file saved previously @@ -241,9 +258,10 @@ impl SupervisedModel { /// let model = SupervisedModel::new_from_file("tests/load_that_model.aml"); /// # std::fs::remove_file("tests/load_that_model.aml"); /// ``` + #[must_use] pub fn new_from_file(file_name: &str) -> Self { let mut buf: Vec = Vec::new(); - std::fs::File::open(&file_name) + std::fs::File::open(file_name) .and_then(|mut f| f.read_to_end(&mut buf)) .expect("Cannot load model from file."); bincode::deserialize(&buf).expect("Can not deserialize the model") @@ -293,11 +311,12 @@ impl SupervisedModel { /// #[cfg(any(feature = "csv"))] /// model.predict("data/diabetes_without_target.csv"); /// ``` - pub fn predict(&self, x: X) -> Vec - where - X: IntoFeatures, - { - let x = &self.preprocess(x.to_dense_matrix().clone()); + /// + /// # Panics + /// + /// If the model has not been trained, this function will panic. 
+ pub fn predict(&self, x: X) -> Vec { + let x = &self.preprocess(x.to_dense_matrix()); match self.settings.final_model_approach { FinalModel::None => panic!(""), FinalModel::Best => self.predict_by_model(x, &self.comparison[0]), @@ -321,13 +340,13 @@ impl SupervisedModel { number_of_components, } = self.settings.preprocessing { - self.train_pca(self.x_train.clone(), number_of_components); + self.train_pca(&self.x_train.clone(), number_of_components); } if let PreProcessing::ReplaceWithSVD { number_of_components, } = self.settings.preprocessing { - self.train_svd(self.x_train.clone(), number_of_components); + self.train_svd(&self.x_train.clone(), number_of_components); } // Preprocess the data @@ -511,13 +530,13 @@ impl SupervisedModel { )); } - match self.settings.final_model_approach { - FinalModel::Blending { - algorithm, - meta_training_fraction, - meta_testing_fraction, - } => self.train_blended_model(algorithm, meta_training_fraction, meta_testing_fraction), - _ => {} + if let FinalModel::Blending { + algorithm, + meta_training_fraction, + meta_testing_fraction, + } = self.settings.final_model_approach + { + self.train_blended_model(algorithm, meta_training_fraction, meta_testing_fraction); } } @@ -553,7 +572,7 @@ impl SupervisedModel { /// # std::fs::remove_file("tests/save_best.sc"); /// ``` pub fn save_best(&self, file_name: &str) { - if let FinalModel::Best = self.settings.final_model_approach { + if matches!(self.settings.final_model_approach, FinalModel::Best) { std::fs::File::create(file_name) .and_then(|mut f| f.write_all(&self.comparison[0].model)) .expect("Cannot write model to file."); @@ -563,20 +582,35 @@ impl SupervisedModel { /// Private functions go here impl SupervisedModel { + /// Build a new supervised model + /// + /// # Arguments + /// + /// * `x` - The input data + /// * `y` - The output data + /// * `settings` - The settings for the model fn build(x: DenseMatrix, y: Vec, settings: Settings) -> Self { Self { settings, - x_train: x.clone(), - y_train: y.clone(), + x_train: x, + number_of_classes: Self::count_classes(&y), + y_train: y, x_val: DenseMatrix::new(0, 0, vec![]), y_val: vec![], - number_of_classes: Self::count_classes(&y), comparison: vec![], - preprocessing: (None, None), - metamodel: Default::default(), + metamodel: Model::default(), + preprocessing_pca: None, + preprocessing_svd: None, } } + /// Train the supervised model. 
+ /// + /// # Arguments + /// + /// * `algo` - The algorithm to use + /// * `training_fraction` - The fraction of the data to use for training + /// * `testing_fraction` - The fraction of the data to use for testing fn train_blended_model( &mut self, algo: Algorithm, @@ -586,7 +620,7 @@ impl SupervisedModel { // Make the data let mut meta_x: Vec> = Vec::new(); for model in &self.comparison { - meta_x.push(self.predict_by_model(&self.x_val, model)) + meta_x.push(self.predict_by_model(&self.x_val, model)); } let xdm = DenseMatrix::from_2d_vec(&meta_x).transpose(); @@ -600,17 +634,17 @@ impl SupervisedModel { // Train the model // let model = LassoRegressorWrapper::train(&x_train, &y_train, &self.settings); - let model = (*algo.get_trainer())(&x_train, &y_train, &self.settings); + let model = algo.get_trainer()(&x_train, &y_train, &self.settings); // Score the model - let train_score = (*self.settings.get_metric())( + let train_score = self.settings.get_metric()( &y_train, - &(*algo.get_predictor())(&x_train, &model, &self.settings), + &algo.get_predictor()(&x_train, &model, &self.settings), // &LassoRegressorWrapper::predict(&x_train, &model, &self.settings), ); - let test_score = (*self.settings.get_metric())( + let test_score = self.settings.get_metric()( &y_test, - &(*algo.get_predictor())(&x_test, &model, &self.settings), + &algo.get_predictor()(&x_test, &model, &self.settings), // &LassoRegressorWrapper::predict(&x_test, &model, &self.settings), ); @@ -620,17 +654,27 @@ impl SupervisedModel { train_score: vec![train_score; 1], }, name: algo, - duration: Default::default(), + duration: Duration::default(), model, }; } + /// Predict using all of the trained models. + /// + /// # Arguments + /// + /// * `x` - The input data + /// * `algo` - The algorithm to use + /// + /// # Returns + /// + /// * The predicted values fn predict_blended_model(&self, x: &DenseMatrix, algo: Algorithm) -> Vec { // Make the data let mut meta_x: Vec> = Vec::new(); for i in 0..self.comparison.len() { let model = &self.comparison[i]; - meta_x.push(self.predict_by_model(&x, model)) + meta_x.push(self.predict_by_model(x, model)); } // @@ -638,52 +682,26 @@ impl SupervisedModel { let metamodel = &self.metamodel.model; // Train the model - (*algo.get_predictor())(&xdm, metamodel, &self.settings) + algo.get_predictor()(&xdm, metamodel, &self.settings) } + /// Predict using a single model. 
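The blending trainer builds its meta-features by collecting each base model's predictions on the held-out validation set, stacking them as rows, and transposing. A standalone sketch of that step, assuming smartcore 0.2's `BaseMatrix` trait provides `transpose` and `shape`:

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linalg::BaseMatrix;

// Blending in miniature: each base model contributes one row of held-out
// predictions; transposing turns rows into columns, giving an
// (n_samples x n_models) matrix of meta-features for the final model.
fn meta_features(per_model_predictions: &[Vec<f32>]) -> DenseMatrix<f32> {
    DenseMatrix::from_2d_vec(&per_model_predictions.to_vec()).transpose()
}

fn main() {
    let preds_model_a = vec![0.0, 1.0, 1.0];
    let preds_model_b = vec![0.0, 1.0, 0.0];
    let xdm = meta_features(&[preds_model_a, preds_model_b]);
    assert_eq!(xdm.shape(), (3, 2)); // 3 samples, 2 base models
}
```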
+ /// + /// # Arguments + /// + /// * `x` - The input data + /// * `model` - The model to use + /// + /// # Returns + /// + /// * The predicted values fn predict_by_model(&self, x: &DenseMatrix, model: &Model) -> Vec { - let saved_model = &model.model; - match model.name { - Algorithm::Linear => LinearRegressorWrapper::predict(x, saved_model, &self.settings), - Algorithm::Lasso => LassoRegressorWrapper::predict(x, saved_model, &self.settings), - Algorithm::Ridge => RidgeRegressorWrapper::predict(x, saved_model, &self.settings), - Algorithm::ElasticNet => { - ElasticNetRegressorWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::RandomForestRegressor => { - RandomForestRegressorWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::KNNRegressor => KNNRegressorWrapper::predict(x, saved_model, &self.settings), - Algorithm::SVR => { - SupportVectorRegressorWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::DecisionTreeRegressor => { - DecisionTreeRegressorWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::LogisticRegression => { - LogisticRegressionWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::RandomForestClassifier => { - RandomForestClassifierWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::DecisionTreeClassifier => { - DecisionTreeClassifierWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::KNNClassifier => { - KNNClassifierWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::SVC => { - SupportVectorClassifierWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::GaussianNaiveBayes => { - GaussianNaiveBayesClassifierWrapper::predict(x, saved_model, &self.settings) - } - Algorithm::CategoricalNaiveBayes => { - CategoricalNaiveBayesClassifierWrapper::predict(x, saved_model, &self.settings) - } - } + model.name.get_predictor()(x, &model.model, &self.settings) } + /// Get interaction features for the data. + /// + /// # Arguments fn interaction_features(mut x: DenseMatrix) -> DenseMatrix { let (_, width) = x.shape(); for i in 0..width { @@ -696,10 +714,20 @@ impl SupervisedModel { x } + /// Get polynomial features for the data. + /// + /// # Arguments + /// + /// * `x` - The input data + /// * `order` - The order of the polynomial + /// + /// # Returns + /// + /// * The data with polynomial features fn polynomial_features(mut x: DenseMatrix, order: usize) -> DenseMatrix { let (height, width) = x.shape(); for n in 2..=order { - let combinations = (0..width).into_iter().combinations_with_replacement(n); + let combinations = (0..width).combinations_with_replacement(n); for combo in combinations { let mut feature = vec![1.0; height]; for column in combo { @@ -712,63 +740,96 @@ impl SupervisedModel { x } - fn train_pca(&mut self, x: DenseMatrix, n: usize) { + /// Train PCA on the data for preprocessing. + /// + /// # Arguments + /// + /// * `x` - The input data + /// * `n` - The number of components to use + fn train_pca(&mut self, x: &DenseMatrix, n: usize) { let pca = PCA::fit( - &x, + x, PCAParameters::default() .with_n_components(n) .with_use_correlation_matrix(true), ) .unwrap(); - self.preprocessing.0 = Some(pca); + self.preprocessing_pca = Some(pca); } - fn pca_features(&self, x: DenseMatrix, n: usize) -> DenseMatrix { - self.preprocessing - .0 + /// Get PCA features for the data using the trained PCA preprocessor. 
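`polynomial_features` leans on itertools' `combinations_with_replacement`, and the dropped `.into_iter()` was redundant because ranges already implement `Iterator`. A small check of what the generator yields for an order-2 expansion over two columns:

```rust
use itertools::Itertools;

// Every size-2 multiset of column indices becomes one new feature column:
// x0*x0, x0*x1, x1*x1.
fn main() {
    let width = 2;
    let combos: Vec<Vec<usize>> = (0..width).combinations_with_replacement(2).collect();
    assert_eq!(combos, vec![vec![0, 0], vec![0, 1], vec![1, 1]]);
}
```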
+ /// + /// # Arguments + /// + /// * `x` - The input data + fn pca_features(&self, x: &DenseMatrix, _: usize) -> DenseMatrix { + self.preprocessing_pca .as_ref() .unwrap() - .transform(&x) + .transform(x) .unwrap() } - fn train_svd(&mut self, x: DenseMatrix, n: usize) { - let svd = SVD::fit(&x, SVDParameters::default().with_n_components(n)).unwrap(); - self.preprocessing.1 = Some(svd); + /// Train SVD on the data for preprocessing. + /// + /// # Arguments + /// + /// * `x` - The input data + /// * `n` - The number of components to use + fn train_svd(&mut self, x: &DenseMatrix, n: usize) { + let svd = SVD::fit(x, SVDParameters::default().with_n_components(n)).unwrap(); + self.preprocessing_svd = Some(svd); } - fn svd_features(&self, x: DenseMatrix, n: usize) -> DenseMatrix { - self.preprocessing - .1 + /// Get SVD features for the data. + fn svd_features(&self, x: &DenseMatrix, _: usize) -> DenseMatrix { + self.preprocessing_svd .as_ref() .unwrap() - .transform(&x) + .transform(x) .unwrap() } + /// Pre process the data. + /// + /// # Arguments + /// + /// * `x` - The input data + /// + /// # Returns + /// + /// * The preprocessed data fn preprocess(&self, x: DenseMatrix) -> DenseMatrix { match self.settings.preprocessing { PreProcessing::None => x, - PreProcessing::AddInteractions => SupervisedModel::interaction_features(x), - PreProcessing::AddPolynomial { order } => { - SupervisedModel::polynomial_features(x, order) - } + PreProcessing::AddInteractions => Self::interaction_features(x), + PreProcessing::AddPolynomial { order } => Self::polynomial_features(x, order), PreProcessing::ReplaceWithPCA { number_of_components, - } => self.pca_features(x, number_of_components), + } => self.pca_features(&x, number_of_components), PreProcessing::ReplaceWithSVD { number_of_components, - } => self.svd_features(x, number_of_components), + } => self.svd_features(&x, number_of_components), } } - fn count_classes(y: &Vec) -> usize { - let mut sorted_targets = y.clone(); - sorted_targets.sort_by(|a, b| a.partial_cmp(&b).unwrap_or(Equal)); + /// Count the number of classes in the data. + /// + /// # Arguments + /// + /// * `y` - The data to count the classes in + /// + /// # Returns + /// + /// * The number of classes + fn count_classes(y: &[f32]) -> usize { + let mut sorted_targets = y.to_vec(); + sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Equal)); sorted_targets.dedup(); sorted_targets.len() } + /// Record a model in the comparison. fn record_model(&mut self, model: (CrossValidationResult, Algorithm, Duration, Vec)) { self.comparison.push(Model { score: model.0, @@ -779,6 +840,7 @@ impl SupervisedModel { self.sort(); } + /// Sort the models in the comparison by their mean test scores. fn sort(&mut self) { self.comparison.sort_by(|a, b| { a.score @@ -848,23 +910,27 @@ impl Display for SupervisedModel { meta_table.add_row(row_vec); // Write - write!(f, "{}\n{}", table, meta_table) + write!(f, "{table}\n{meta_table}") } } /// This contains the results of a single model #[derive(serde::Serialize, serde::Deserialize)] struct Model { + /// The cross validation score of the model #[serde(with = "CrossValidationResultDef")] score: CrossValidationResult, + /// The algorithm used name: Algorithm, + /// The time it took to train the model duration: Duration, + /// What is this? 
The bincode-serialized bytes of the trained model. model: Vec<u8>, } impl Default for Model { fn default() -> Self { - Model { + Self { score: CrossValidationResult { test_score: vec![], train_score: vec![], }, @@ -876,6 +942,7 @@ impl Default for Model { } } +/// This is a wrapper for the `CrossValidationResult` #[derive(serde::Serialize, serde::Deserialize)] #[serde(remote = "CrossValidationResult::<f32>")] struct CrossValidationResultDef { diff --git a/src/settings/knn_classifier_parameters.rs b/src/settings/knn_classifier_parameters.rs index f94fef6..f79e704 100644 --- a/src/settings/knn_classifier_parameters.rs +++ b/src/settings/knn_classifier_parameters.rs @@ -1,36 +1,46 @@ +//! KNN classifier parameters + use crate::utils::Distance; pub use smartcore::{algorithm::neighbour::KNNAlgorithmName, neighbors::KNNWeightFunction}; /// Parameters for k-nearest neighbors (KNN) classification #[derive(serde::Serialize, serde::Deserialize)] pub struct KNNClassifierParameters { + /// Number of nearest neighbors to use pub(crate) k: usize, + /// Weighting function to use with KNN classification pub(crate) weight: KNNWeightFunction, + /// Search algorithm to use with KNN classification pub(crate) algorithm: KNNAlgorithmName, + /// Distance metric to use with KNN classification pub(crate) distance: Distance, } impl KNNClassifierParameters { /// Define the number of nearest neighbors to use - pub fn with_k(mut self, k: usize) -> Self { + #[must_use] + pub const fn with_k(mut self, k: usize) -> Self { self.k = k; self } /// Define the weighting function to use with KNN classification - pub fn with_weight(mut self, weight: KNNWeightFunction) -> Self { + #[must_use] + pub const fn with_weight(mut self, weight: KNNWeightFunction) -> Self { self.weight = weight; self } /// Define the search algorithm to use with KNN classification - pub fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self { + #[must_use] + pub const fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self { self.algorithm = algorithm; self } /// Define the distance metric to use with KNN classification - pub fn with_distance(mut self, distance: Distance) -> Self { + #[must_use] + pub const fn with_distance(mut self, distance: Distance) -> Self { self.distance = distance; self } diff --git a/src/settings/knn_regressor_parameters.rs b/src/settings/knn_regressor_parameters.rs index da1e6c6..89bd78b 100644 --- a/src/settings/knn_regressor_parameters.rs +++ b/src/settings/knn_regressor_parameters.rs @@ -1,36 +1,46 @@ +//!
KNN regressor parameters + use crate::utils::Distance; pub use smartcore::{algorithm::neighbour::KNNAlgorithmName, neighbors::KNNWeightFunction}; /// Parameters for k-nearest neighbor (KNN) regression #[derive(serde::Serialize, serde::Deserialize)] pub struct KNNRegressorParameters { + /// Number of nearest neighbors to use pub(crate) k: usize, + /// Weighting function to use with KNN regression pub(crate) weight: KNNWeightFunction, + /// Search algorithm to use with KNN regression pub(crate) algorithm: KNNAlgorithmName, + /// Distance metric to use with KNN regression pub(crate) distance: Distance, } impl KNNRegressorParameters { /// Define the number of nearest neighbors to use - pub fn with_k(mut self, k: usize) -> Self { + #[must_use] + pub const fn with_k(mut self, k: usize) -> Self { self.k = k; self } /// Define the weighting function to use with KNN regression - pub fn with_weight(mut self, weight: KNNWeightFunction) -> Self { + #[must_use] + pub const fn with_weight(mut self, weight: KNNWeightFunction) -> Self { self.weight = weight; self } /// Define the search algorithm to use with KNN regression - pub fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self { + #[must_use] + pub const fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self { self.algorithm = algorithm; self } /// Define the distance metric to use with KNN regression - pub fn with_distance(mut self, distance: Distance) -> Self { + #[must_use] + pub const fn with_distance(mut self, distance: Distance) -> Self { self.distance = distance; self } diff --git a/src/settings/mod.rs b/src/settings/mod.rs index b359009..96cbdf4 100644 --- a/src/settings/mod.rs +++ b/src/settings/mod.rs @@ -201,7 +201,7 @@ pub use settings_struct::Settings; /// Metrics for evaluating algorithms #[non_exhaustive] -#[derive(PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum Metric { /// Sort by R^2 RSquared, @@ -218,17 +218,17 @@ pub enum Metric { impl Display for Metric { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Metric::RSquared => write!(f, "R^2"), - Metric::MeanAbsoluteError => write!(f, "MAE"), - Metric::MeanSquaredError => write!(f, "MSE"), - Metric::Accuracy => write!(f, "Accuracy"), - Metric::None => panic!("A metric must be set."), + Self::RSquared => write!(f, "R^2"), + Self::MeanAbsoluteError => write!(f, "MAE"), + Self::MeanSquaredError => write!(f, "MSE"), + Self::Accuracy => write!(f, "Accuracy"), + Self::None => panic!("A metric must be set."), } } } /// Algorithm options -#[derive(PartialEq, Copy, Clone, serde::Serialize, serde::Deserialize)] +#[derive(PartialEq, Eq, Copy, Clone, serde::Serialize, serde::Deserialize)] pub enum Algorithm { /// Decision tree regressor DecisionTreeRegressor, @@ -263,51 +263,45 @@ pub enum Algorithm { } impl Algorithm { - pub(crate) fn get_predictor( - &self, - ) -> Box, &Vec, &Settings) -> Vec> { + /// Get the `predict` method for the underlying algorithm. 
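Both KNN parameter structs now expose `const fn` setters marked `#[must_use]`, so dropping the configured value becomes a compiler warning. A hedged usage sketch; the `Default` implementation and the `automl::settings` re-export paths are assumptions, not part of this diff:

```rust
use automl::settings::{Distance, KNNClassifierParameters, KNNWeightFunction};

fn main() {
    // Chain the builder-style setters; #[must_use] warns if the result is
    // silently discarded. `default()` is assumed from the crate.
    let params = KNNClassifierParameters::default()
        .with_k(7)
        .with_weight(KNNWeightFunction::Distance)
        .with_distance(Distance::Euclidean);
    let _ = params;
}
```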
+ pub(crate) fn get_predictor(self) -> fn(&DenseMatrix, &Vec, &Settings) -> Vec { match self { - Algorithm::Linear => Box::new(LinearRegressorWrapper::predict), - Algorithm::Lasso => Box::new(LassoRegressorWrapper::predict), - Algorithm::Ridge => Box::new(RidgeRegressorWrapper::predict), - Algorithm::ElasticNet => Box::new(ElasticNetRegressorWrapper::predict), - Algorithm::RandomForestRegressor => Box::new(RandomForestRegressorWrapper::predict), - Algorithm::KNNRegressor => Box::new(KNNRegressorWrapper::predict), - Algorithm::SVR => Box::new(SupportVectorRegressorWrapper::predict), - Algorithm::DecisionTreeRegressor => Box::new(DecisionTreeRegressorWrapper::predict), - Algorithm::LogisticRegression => Box::new(LogisticRegressionWrapper::predict), - Algorithm::RandomForestClassifier => Box::new(RandomForestClassifierWrapper::predict), - Algorithm::DecisionTreeClassifier => Box::new(DecisionTreeClassifierWrapper::predict), - Algorithm::KNNClassifier => Box::new(KNNClassifierWrapper::predict), - Algorithm::SVC => Box::new(SupportVectorClassifierWrapper::predict), - Algorithm::GaussianNaiveBayes => Box::new(GaussianNaiveBayesClassifierWrapper::predict), - Algorithm::CategoricalNaiveBayes => { - Box::new(CategoricalNaiveBayesClassifierWrapper::predict) - } + Self::Linear => LinearRegressorWrapper::predict, + Self::Lasso => LassoRegressorWrapper::predict, + Self::Ridge => RidgeRegressorWrapper::predict, + Self::ElasticNet => ElasticNetRegressorWrapper::predict, + Self::RandomForestRegressor => RandomForestRegressorWrapper::predict, + Self::KNNRegressor => KNNRegressorWrapper::predict, + Self::SVR => SupportVectorRegressorWrapper::predict, + Self::DecisionTreeRegressor => DecisionTreeRegressorWrapper::predict, + Self::LogisticRegression => LogisticRegressionWrapper::predict, + Self::RandomForestClassifier => RandomForestClassifierWrapper::predict, + Self::DecisionTreeClassifier => DecisionTreeClassifierWrapper::predict, + Self::KNNClassifier => KNNClassifierWrapper::predict, + Self::SVC => SupportVectorClassifierWrapper::predict, + Self::GaussianNaiveBayes => GaussianNaiveBayesClassifierWrapper::predict, + Self::CategoricalNaiveBayes => CategoricalNaiveBayesClassifierWrapper::predict, } } - pub(crate) fn get_trainer( - &self, - ) -> Box, &Vec, &Settings) -> Vec> { + /// Get the `train` method for the underlying algorithm. 
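`get_predictor` (and `get_trainer` below) now return plain `fn` pointers instead of `Box<dyn Fn(...)>`, which avoids the heap allocation on every lookup and the fat-pointer indirection of a boxed closure. The pattern in miniature:

```rust
// Non-capturing closures coerce to plain `fn` pointers, so an enum can hand
// out behavior without boxing anything.
#[derive(Clone, Copy)]
enum Op {
    Double,
    Square,
}

impl Op {
    fn get(self) -> fn(f32) -> f32 {
        match self {
            Self::Double => |x| 2.0 * x,
            Self::Square => |x| x * x,
        }
    }
}

fn main() {
    assert_eq!(Op::Double.get()(3.0), 6.0);
    assert_eq!(Op::Square.get()(3.0), 9.0);
}
```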
-    pub(crate) fn get_trainer(
-        &self,
-    ) -> Box<dyn Fn(&DenseMatrix<f32>, &Vec<f32>, &Settings) -> Vec<u8>> {
+    /// Get the `train` method for the underlying algorithm.
+    pub(crate) fn get_trainer(self) -> fn(&DenseMatrix<f32>, &Vec<f32>, &Settings) -> Vec<u8> {
         match self {
-            Algorithm::Linear => Box::new(LinearRegressorWrapper::train),
-            Algorithm::Lasso => Box::new(LassoRegressorWrapper::train),
-            Algorithm::Ridge => Box::new(RidgeRegressorWrapper::train),
-            Algorithm::ElasticNet => Box::new(ElasticNetRegressorWrapper::train),
-            Algorithm::RandomForestRegressor => Box::new(RandomForestRegressorWrapper::train),
-            Algorithm::KNNRegressor => Box::new(KNNRegressorWrapper::train),
-            Algorithm::SVR => Box::new(SupportVectorRegressorWrapper::train),
-            Algorithm::DecisionTreeRegressor => Box::new(DecisionTreeRegressorWrapper::train),
-            Algorithm::LogisticRegression => Box::new(LogisticRegressionWrapper::train),
-            Algorithm::RandomForestClassifier => Box::new(RandomForestClassifierWrapper::train),
-            Algorithm::DecisionTreeClassifier => Box::new(DecisionTreeClassifierWrapper::train),
-            Algorithm::KNNClassifier => Box::new(KNNClassifierWrapper::train),
-            Algorithm::SVC => Box::new(SupportVectorClassifierWrapper::train),
-            Algorithm::GaussianNaiveBayes => Box::new(GaussianNaiveBayesClassifierWrapper::train),
-            Algorithm::CategoricalNaiveBayes => {
-                Box::new(CategoricalNaiveBayesClassifierWrapper::train)
-            }
+            Self::Linear => LinearRegressorWrapper::train,
+            Self::Lasso => LassoRegressorWrapper::train,
+            Self::Ridge => RidgeRegressorWrapper::train,
+            Self::ElasticNet => ElasticNetRegressorWrapper::train,
+            Self::RandomForestRegressor => RandomForestRegressorWrapper::train,
+            Self::KNNRegressor => KNNRegressorWrapper::train,
+            Self::SVR => SupportVectorRegressorWrapper::train,
+            Self::DecisionTreeRegressor => DecisionTreeRegressorWrapper::train,
+            Self::LogisticRegression => LogisticRegressionWrapper::train,
+            Self::RandomForestClassifier => RandomForestClassifierWrapper::train,
+            Self::DecisionTreeClassifier => DecisionTreeClassifierWrapper::train,
+            Self::KNNClassifier => KNNClassifierWrapper::train,
+            Self::SVC => SupportVectorClassifierWrapper::train,
+            Self::GaussianNaiveBayes => GaussianNaiveBayesClassifierWrapper::train,
+            Self::CategoricalNaiveBayes => CategoricalNaiveBayesClassifierWrapper::train,
         }
     }
 }
@@ -315,21 +309,21 @@ impl Algorithm {
 impl Display for Algorithm {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
-            Algorithm::DecisionTreeRegressor => write!(f, "Decision Tree Regressor"),
-            Algorithm::KNNRegressor => write!(f, "KNN Regressor"),
-            Algorithm::RandomForestRegressor => write!(f, "Random Forest Regressor"),
-            Algorithm::Linear => write!(f, "Linear Regressor"),
-            Algorithm::Ridge => write!(f, "Ridge Regressor"),
-            Algorithm::Lasso => write!(f, "LASSO Regressor"),
-            Algorithm::ElasticNet => write!(f, "Elastic Net Regressor"),
-            Algorithm::SVR => write!(f, "Support Vector Regressor"),
-            Algorithm::DecisionTreeClassifier => write!(f, "Decision Tree Classifier"),
-            Algorithm::KNNClassifier => write!(f, "KNN Classifier"),
-            Algorithm::RandomForestClassifier => write!(f, "Random Forest Classifier"),
-            Algorithm::LogisticRegression => write!(f, "Logistic Regression Classifier"),
-            Algorithm::SVC => write!(f, "Support Vector Classifier"),
-            Algorithm::GaussianNaiveBayes => write!(f, "Gaussian Naive Bayes"),
-            Algorithm::CategoricalNaiveBayes => write!(f, "Categorical Naive Bayes"),
+            Self::DecisionTreeRegressor => write!(f, "Decision Tree Regressor"),
+            Self::KNNRegressor => write!(f, "KNN Regressor"),
+            Self::RandomForestRegressor => write!(f, "Random Forest Regressor"),
+            Self::Linear => write!(f, "Linear Regressor"),
+            Self::Ridge => write!(f, "Ridge Regressor"),
+            Self::Lasso => write!(f, "LASSO Regressor"),
+            Self::ElasticNet => write!(f, "Elastic Net Regressor"),
+            Self::SVR => write!(f, "Support Vector Regressor"),
+            Self::DecisionTreeClassifier => write!(f, "Decision Tree Classifier"),
+            Self::KNNClassifier => write!(f, "KNN Classifier"),
+            Self::RandomForestClassifier => write!(f, "Random Forest Classifier"),
+            Self::LogisticRegression => write!(f, "Logistic Regression Classifier"),
+            Self::SVC => write!(f, "Support Vector Classifier"),
+            Self::GaussianNaiveBayes => write!(f, "Gaussian Naive Bayes"),
+            Self::CategoricalNaiveBayes => write!(f, "Categorical Naive Bayes"),
         }
     }
 }
@@ -361,26 +355,18 @@ pub enum PreProcessing {
 impl Display for PreProcessing {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
-            PreProcessing::None => write!(f, "None"),
-            PreProcessing::AddInteractions => write!(f, "Interaction terms added"),
-            PreProcessing::AddPolynomial { order } => {
-                write!(f, "Polynomial terms added (order = {})", order)
+            Self::None => write!(f, "None"),
+            Self::AddInteractions => write!(f, "Interaction terms added"),
+            Self::AddPolynomial { order } => {
+                write!(f, "Polynomial terms added (order = {order})")
             }
-            PreProcessing::ReplaceWithPCA {
+            Self::ReplaceWithPCA {
                 number_of_components,
-            } => write!(
-                f,
-                "Replaced with PCA features (n = {})",
-                number_of_components
-            ),
+            } => write!(f, "Replaced with PCA features (n = {number_of_components})"),
 
-            PreProcessing::ReplaceWithSVD {
+            Self::ReplaceWithSVD {
                 number_of_components,
-            } => write!(
-                f,
-                "Replaced with SVD features (n = {})",
-                number_of_components
-            ),
+            } => write!(f, "Replaced with SVD features (n = {number_of_components})"),
         }
     }
 }
@@ -410,7 +396,8 @@ pub enum FinalModel {
 impl FinalModel {
     /// Default values for a blending model (linear regression, 30% of all data reserved for training the blending model)
-    pub fn default_blending() -> FinalModel {
+    #[must_use]
+    pub const fn default_blending() -> Self {
         Self::Blending {
             algorithm: Algorithm::Linear,
             meta_training_fraction: 0.15,
diff --git a/src/settings/settings_struct.rs b/src/settings/settings_struct.rs
index 2c37370..c5e2206 100644
--- a/src/settings/settings_struct.rs
+++ b/src/settings/settings_struct.rs
@@ -1,3 +1,5 @@
+//! Settings for the automl crate
+
 use comfy_table::{
     modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL, Attribute, Cell, Table,
 };
@@ -25,36 +27,61 @@ use std::fmt::{Display, Formatter};
 use std::io::{Read, Write};
 
 /// Settings for supervised models
+///
+/// Any algorithms in the `skiplist` member will be skipped during training.
 #[derive(serde::Serialize, serde::Deserialize)]
 pub struct Settings {
+    /// The metric to sort by
     pub(crate) sort_by: Metric,
+    /// The type of model to train
     model_type: ModelType,
+    /// The algorithms to skip
     pub(crate) skiplist: Vec<Algorithm>,
+    /// The number of folds for cross-validation
     number_of_folds: usize,
+    /// Whether or not to shuffle the data
     pub(crate) shuffle: bool,
+    /// Whether or not to be verbose
     verbose: bool,
+    /// The approach to use for the final model
     pub(crate) final_model_approach: FinalModel,
+    /// The kind of preprocessing to perform
     pub(crate) preprocessing: PreProcessing,
+    /// Optional settings for linear regression
     pub(crate) linear_settings: Option<LinearRegressionParameters>,
+    /// Optional settings for support vector regressor
     pub(crate) svr_settings: Option<SVRParameters>,
+    /// Optional settings for lasso regression
     pub(crate) lasso_settings: Option<LassoParameters<f32>>,
+    /// Optional settings for ridge regression
     pub(crate) ridge_settings: Option<RidgeRegressionParameters<f32>>,
+    /// Optional settings for elastic net
     pub(crate) elastic_net_settings: Option<ElasticNetParameters<f32>>,
+    /// Optional settings for decision tree regressor
     pub(crate) decision_tree_regressor_settings: Option<DecisionTreeRegressorParameters>,
+    /// Optional settings for random forest regressor
     pub(crate) random_forest_regressor_settings: Option<RandomForestRegressorParameters>,
+    /// Optional settings for KNN regressor
     pub(crate) knn_regressor_settings: Option<KNNRegressorParameters>,
+    /// Optional settings for logistic regression
     pub(crate) logistic_settings: Option<LogisticRegressionParameters<f32>>,
+    /// Optional settings for random forest
     pub(crate) random_forest_classifier_settings: Option<RandomForestClassifierParameters>,
+    /// Optional settings for KNN classifier
     pub(crate) knn_classifier_settings: Option<KNNClassifierParameters>,
+    /// Optional settings for support vector classifier
     pub(crate) svc_settings: Option<SVCParameters>,
+    /// Optional settings for decision tree classifier
     pub(crate) decision_tree_classifier_settings: Option<DecisionTreeClassifierParameters>,
+    /// Optional settings for Gaussian Naive Bayes
     pub(crate) gaussian_nb_settings: Option<GaussianNBParameters<f32>>,
+    /// Optional settings for Categorical Naive Bayes
     pub(crate) categorical_nb_settings: Option<CategoricalNBParameters<f32>>,
 }
 
 impl Default for Settings {
     fn default() -> Self {
-        Settings {
+        Self {
             sort_by: Metric::RSquared,
             model_type: ModelType::None,
             final_model_approach: FinalModel::Best,
@@ -99,20 +126,22 @@
 }
 
 impl Settings {
+    /// Get the k-fold cross-validator
     pub(crate) fn get_kfolds(&self) -> KFold {
         KFold::default()
             .with_n_splits(self.number_of_folds)
             .with_shuffle(self.shuffle)
     }
 
-    pub(crate) fn get_metric(&self) -> Box<dyn Fn(&Vec<f32>, &Vec<f32>) -> f32> {
-        Box::new(match self.sort_by {
+    /// Get the metric to sort by
+    pub(crate) fn get_metric(&self) -> fn(&Vec<f32>, &Vec<f32>) -> f32 {
+        match self.sort_by {
             Metric::RSquared => r2,
             Metric::MeanAbsoluteError => mean_absolute_error,
             Metric::MeanSquaredError => mean_squared_error,
             Metric::Accuracy => accuracy,
             Metric::None => panic!("A metric must be set."),
-        })
+        }
     }
 
     /// Creates default settings for regression
     /// ```
     /// # use automl::Settings;
     /// let settings = Settings::default_regression();
     /// ```
+    #[must_use]
     pub fn default_regression() -> Self {
-        Settings {
+        Self {
             sort_by: Metric::RSquared,
             model_type: ModelType::Regression,
             final_model_approach: FinalModel::Best,
@@ -161,8 +191,9 @@
     /// # use automl::Settings;
     /// let settings = Settings::default_classification();
     /// ```
+    #[must_use]
     pub fn default_classification() -> Self {
-        Settings {
+        Self {
             sort_by: Metric::Accuracy,
             model_type: ModelType::Classification,
             final_model_approach: FinalModel::Best,
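For reference, the `skiplist` called out in the new field-level docs is populated through the builder methods that follow in this file (`skip`, `only`, `sorted_by`). A short usage sketch, assuming only the APIs visible in this diff:

```rust
use automl::settings::{Algorithm, Metric};
use automl::Settings;

fn main() {
    // Start from the classification defaults, then narrow the comparison:
    // every algorithm pushed onto the skiplist is ignored during training.
    let settings = Settings::default_classification()
        .skip(Algorithm::SVC)
        .skip(Algorithm::CategoricalNaiveBayes)
        .sorted_by(Metric::Accuracy);

    // `Settings` implements `Display`, so the configuration can be inspected.
    println!("{settings}");
}
```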
@@ -206,12 +237,13 @@
     /// let settings = Settings::new_from_file("tests/load_those_settings.yaml");
     /// # std::fs::remove_file("tests/load_those_settings.yaml");
     /// ```
+    #[must_use]
     pub fn new_from_file(file_name: &str) -> Self {
         let mut buf: Vec<u8> = Vec::new();
-        std::fs::File::open(&file_name)
+        std::fs::File::open(file_name)
             .and_then(|mut f| f.read_to_end(&mut buf))
             .expect("Cannot read settings file.");
-        serde_yaml::from_slice(&*buf).expect("Cannot deserialize settings file.")
+        serde_yaml::from_slice(&buf).expect("Cannot deserialize settings file.")
     }
 
     /// Save the current settings to a file for later use
@@ -224,8 +256,8 @@
     pub fn save(&self, file_name: &str) {
         let serial = serde_yaml::to_string(&self).expect("Cannot serialize settings.");
         std::fs::File::create(file_name)
-            .and_then(|mut f| f.write_all((&serial).as_ref()))
-            .expect("Cannot write settings to file.")
+            .and_then(|mut f| f.write_all(serial.as_ref()))
+            .expect("Cannot write settings to file.");
     }
 
     /// Specify number of folds for cross-validation
     /// ```
     /// # use automl::Settings;
     /// let settings = Settings::default().with_number_of_folds(3);
     /// ```
-    pub fn with_number_of_folds(mut self, n: usize) -> Self {
+    #[must_use]
+    pub const fn with_number_of_folds(mut self, n: usize) -> Self {
         self.number_of_folds = n;
         self
     }
@@ -243,7 +276,8 @@
     /// Specify whether or not data should be shuffled
     /// ```
     /// # use automl::Settings;
     /// let settings = Settings::default().shuffle_data(true);
     /// ```
-    pub fn shuffle_data(mut self, shuffle: bool) -> Self {
+    #[must_use]
+    pub const fn shuffle_data(mut self, shuffle: bool) -> Self {
         self.shuffle = shuffle;
         self
     }
@@ -253,7 +287,8 @@
     /// Specify whether or not to be verbose
     /// ```
     /// # use automl::Settings;
     /// let settings = Settings::default().verbose(true);
     /// ```
-    pub fn verbose(mut self, verbose: bool) -> Self {
+    #[must_use]
+    pub const fn verbose(mut self, verbose: bool) -> Self {
         self.verbose = verbose;
         self
     }
@@ -264,7 +299,8 @@
     /// # use automl::Settings;
     /// use automl::settings::PreProcessing;
     /// let settings = Settings::default().with_preprocessing(PreProcessing::AddInteractions);
     /// ```
-    pub fn with_preprocessing(mut self, pre: PreProcessing) -> Self {
+    #[must_use]
+    pub const fn with_preprocessing(mut self, pre: PreProcessing) -> Self {
         self.preprocessing = pre;
         self
     }
@@ -275,7 +311,8 @@
     /// # use automl::Settings;
     /// use automl::settings::FinalModel;
     /// let settings = Settings::default().with_final_model(FinalModel::Best);
     /// ```
-    pub fn with_final_model(mut self, approach: FinalModel) -> Self {
+    #[must_use]
+    pub const fn with_final_model(mut self, approach: FinalModel) -> Self {
         self.final_model_approach = approach;
         self
     }
@@ -286,6 +323,7 @@
     /// use automl::settings::Algorithm;
     /// let settings = Settings::default().skip(Algorithm::RandomForestRegressor);
     /// ```
+    #[must_use]
     pub fn skip(mut self, skip: Algorithm) -> Self {
         self.skiplist.push(skip);
         self
     }
@@ -297,6 +335,7 @@
     /// use automl::settings::Algorithm;
     /// let settings = Settings::default().only(Algorithm::RandomForestRegressor);
     /// ```
+    #[must_use]
     pub fn only(mut self, only: Algorithm) -> Self {
         self.skiplist = Self::default().skiplist;
         self.skiplist.retain(|&algo| algo != only);
         self
     }
@@ -309,12 +348,13 @@
     /// use automl::settings::Metric;
     /// let settings = Settings::default().sorted_by(Metric::RSquared);
     /// ```
-    pub fn sorted_by(mut self, sort_by: Metric) -> Self {
+    #[must_use]
+    pub const fn sorted_by(mut self, sort_by: Metric) -> Self {
         self.sort_by = sort_by;
         self
     }
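`save` and `new_from_file` above round-trip the whole configuration through YAML, which is handy for pinning down a comparison run. A sketch of the round trip (the file name is arbitrary):

```rust
use automl::Settings;

fn main() {
    // Serialize a configuration to YAML on disk...
    let settings = Settings::default_regression().with_number_of_folds(3);
    settings.save("scratch_settings.yaml");

    // ...and load it back later, e.g. in another run.
    let reloaded = Settings::new_from_file("scratch_settings.yaml");
    println!("{reloaded}");

    // Remove the scratch file so the example leaves nothing behind.
    std::fs::remove_file("scratch_settings.yaml").ok();
}
```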
 
-    /// Specify settings for random_forest
+    /// Specify settings for Random Forest Classifier
     /// ```
     /// # use automl::Settings;
     /// use automl::settings::RandomForestClassifierParameters;
@@ -327,7 +367,8 @@
     ///         .with_min_samples_split(20)
     ///     );
     /// ```
-    pub fn with_random_forest_classifier_settings(
+    #[must_use]
+    pub const fn with_random_forest_classifier_settings(
         mut self,
         settings: RandomForestClassifierParameters,
     ) -> Self {
@@ -342,7 +383,11 @@
     /// let settings = Settings::default()
     ///     .with_logistic_settings(LogisticRegressionParameters::default());
     /// ```
-    pub fn with_logistic_settings(mut self, settings: LogisticRegressionParameters<f32>) -> Self {
+    #[must_use]
+    pub const fn with_logistic_settings(
+        mut self,
+        settings: LogisticRegressionParameters<f32>,
+    ) -> Self {
         self.logistic_settings = Some(settings);
         self
     }
@@ -359,7 +404,8 @@
     ///         .with_kernel(Kernel::Linear)
     ///     );
     /// ```
-    pub fn with_svc_settings(mut self, settings: SVCParameters) -> Self {
+    #[must_use]
+    pub const fn with_svc_settings(mut self, settings: SVCParameters) -> Self {
         self.svc_settings = Some(settings);
         self
     }
@@ -375,7 +421,8 @@
     ///         .with_min_samples_leaf(20)
     ///     );
     /// ```
-    pub fn with_decision_tree_classifier_settings(
+    #[must_use]
+    pub const fn with_decision_tree_classifier_settings(
         mut self,
         settings: DecisionTreeClassifierParameters,
     ) -> Self {
@@ -396,7 +443,8 @@
     ///         .with_weight(KNNWeightFunction::Uniform)
     ///     );
     /// ```
-    pub fn with_knn_classifier_settings(mut self, settings: KNNClassifierParameters) -> Self {
+    #[must_use]
+    pub const fn with_knn_classifier_settings(mut self, settings: KNNClassifierParameters) -> Self {
         self.knn_classifier_settings = Some(settings);
         self
     }
@@ -410,6 +458,8 @@
     ///         .with_priors(vec![1.0, 1.0])
     ///     );
     /// ```
+    #[allow(clippy::missing_const_for_fn)]
+    #[must_use]
     pub fn with_gaussian_nb_settings(mut self, settings: GaussianNBParameters<f32>) -> Self {
         self.gaussian_nb_settings = Some(settings);
         self
     }
@@ -424,7 +474,11 @@
     ///         .with_alpha(1.0)
     ///     );
     /// ```
-    pub fn with_categorical_nb_settings(mut self, settings: CategoricalNBParameters<f32>) -> Self {
+    #[must_use]
+    pub const fn with_categorical_nb_settings(
+        mut self,
+        settings: CategoricalNBParameters<f32>,
+    ) -> Self {
         self.categorical_nb_settings = Some(settings);
         self
     }
@@ -438,7 +492,8 @@
     ///         .with_solver(LinearRegressionSolverName::QR)
     ///     );
     /// ```
-    pub fn with_linear_settings(mut self, settings: LinearRegressionParameters) -> Self {
+    #[must_use]
+    pub const fn with_linear_settings(mut self, settings: LinearRegressionParameters) -> Self {
         self.linear_settings = Some(settings);
         self
     }
@@ -455,7 +510,8 @@
     ///         .with_max_iter(10_000)
     ///     );
     /// ```
-    pub fn with_lasso_settings(mut self, settings: LassoParameters<f32>) -> Self {
+    #[must_use]
+    pub const fn with_lasso_settings(mut self, settings: LassoParameters<f32>) -> Self {
         self.lasso_settings = Some(settings);
         self
     }
@@ -471,7 +527,8 @@
     ///         .with_solver(RidgeRegressionSolverName::Cholesky)
     ///     );
     /// ```
-    pub fn with_ridge_settings(mut self, settings: RidgeRegressionParameters<f32>) -> Self {
+    #[must_use]
+    pub const fn with_ridge_settings(mut self, settings: RidgeRegressionParameters<f32>) -> Self {
         self.ridge_settings = Some(settings);
         self
     }
@@ -489,7 +546,8 @@
     ///         .with_l1_ratio(0.5)
     ///     );
     /// ```
-    pub fn with_elastic_net_settings(mut self, settings: ElasticNetParameters<f32>) -> Self {
+    #[must_use]
+    pub const fn with_elastic_net_settings(mut self, settings: ElasticNetParameters<f32>) -> Self {
         self.elastic_net_settings = Some(settings);
         self
     }
@@ -507,7 +565,8 @@
     ///         .with_weight(KNNWeightFunction::Uniform)
     ///     );
     /// ```
-    pub fn with_knn_regressor_settings(mut self, settings: KNNRegressorParameters) -> Self {
+    #[must_use]
+    pub const fn with_knn_regressor_settings(mut self, settings: KNNRegressorParameters) -> Self {
         self.knn_regressor_settings = Some(settings);
         self
     }
@@ -524,7 +583,8 @@
     ///         .with_kernel(Kernel::Linear)
     ///     );
     /// ```
-    pub fn with_svr_settings(mut self, settings: SVRParameters) -> Self {
+    #[must_use]
+    pub const fn with_svr_settings(mut self, settings: SVRParameters) -> Self {
         self.svr_settings = Some(settings);
         self
     }
@@ -542,7 +602,8 @@
     ///         .with_min_samples_split(20)
     ///     );
     /// ```
-    pub fn with_random_forest_regressor_settings(
+    #[must_use]
+    pub const fn with_random_forest_regressor_settings(
         mut self,
         settings: RandomForestRegressorParameters,
     ) -> Self {
@@ -561,7 +622,8 @@
     ///         .with_min_samples_leaf(20)
     ///     );
     /// ```
-    pub fn with_decision_tree_regressor_settings(
+    #[must_use]
+    pub const fn with_decision_tree_regressor_settings(
         mut self,
         settings: DecisionTreeRegressorParameters,
     ) -> Self {
@@ -577,11 +639,11 @@ impl Display for Settings {
         // Get list of algorithms to skip
         let mut skiplist = String::new();
-        if self.skiplist.len() == 0 {
+        if self.skiplist.is_empty() {
             skiplist.push_str("None ");
         } else {
             for algorithm_to_skip in &self.skiplist {
-                skiplist.push_str(&*format!("{}\n", algorithm_to_skip));
+                skiplist.push_str(&format!("{algorithm_to_skip}\n"));
             }
         }
 
@@ -608,7 +670,7 @@
             ])
             .add_row(vec![
                 " Skipped Algorithms",
-                &*format!("{}", &skiplist[0..skiplist.len() - 1]),
+                &skiplist[0..skiplist.len() - 1],
             ]);
         if !self.skiplist.contains(&Algorithm::Linear) {
             table
@@ -792,20 +854,14 @@
                 ])
                 .add_row(vec![
                     " Search algorithm",
-                    &*format!(
-                        "{}",
-                        print_knn_search_algorithm(
-                            &self.knn_regressor_settings.as_ref().unwrap().algorithm
-                        )
+                    &print_knn_search_algorithm(
+                        &self.knn_regressor_settings.as_ref().unwrap().algorithm,
                     ),
                 ])
                 .add_row(vec![
                     " Weighting function",
-                    &*format!(
-                        "{}",
-                        print_knn_weight_function(
-                            &self.knn_regressor_settings.as_ref().unwrap().weight
-                        )
+                    &print_knn_weight_function(
+                        &self.knn_regressor_settings.as_ref().unwrap().weight,
                    ),
                 ])
                 .add_row(vec![
@@ -922,20 +978,14 @@
                 ])
                 .add_row(vec![
                     " Search algorithm",
-                    &*format!(
-                        "{}",
-                        print_knn_search_algorithm(
-                            &self.knn_classifier_settings.as_ref().unwrap().algorithm
-                        )
+                    &print_knn_search_algorithm(
+                        &self.knn_classifier_settings.as_ref().unwrap().algorithm,
                     ),
                 ])
                 .add_row(vec![
                     " Weighting function",
-                    &*format!(
-                        "{}",
-                        print_knn_weight_function(
-                            &self.knn_classifier_settings.as_ref().unwrap().weight
-                        )
+                    &print_knn_weight_function(
+                        &self.knn_classifier_settings.as_ref().unwrap().weight,
                     ),
                 ])
                 .add_row(vec![
@@ -1041,23 +1091,27 @@
             ]);
         }
 
-        write!(f, "{}\n", table)
+        writeln!(f, "{table}")
     }
 }
 
+/// Model type to train
 #[derive(serde::Serialize, serde::Deserialize)]
 enum ModelType {
+    /// No model type specified
     None,
+    /// Regression model
     Regression,
+    /// Classification model
     Classification,
 }
 
 impl Display for ModelType {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
-            ModelType::None => write!(f, "None"),
-            ModelType::Regression => write!(f, "Regression"),
-            ModelType::Classification => write!(f, "Classification"),
+            Self::None => write!(f, "None"),
+            Self::Regression => write!(f, "Regression"),
+            Self::Classification => write!(f, "Classification"),
         }
     }
 }
diff --git a/src/settings/svc_parameters.rs b/src/settings/svc_parameters.rs
index d2d8db4..9199393 100644
--- a/src/settings/svc_parameters.rs
+++ b/src/settings/svc_parameters.rs
@@ -1,35 +1,45 @@
+//! Support Vector Classification parameters
+
 pub use crate::utils::Kernel;
 
 /// Parameters for support vector classification
 #[derive(serde::Serialize, serde::Deserialize)]
 pub struct SVCParameters {
+    /// Number of epochs to use in the epsilon-SVC model
     pub(crate) epoch: usize,
+    /// Regularization penalty to use with the SVC model
     pub(crate) c: f32,
+    /// Convergence tolerance to use with the SVC model
     pub(crate) tol: f32,
+    /// Kernel to use with the SVC model
     pub(crate) kernel: Kernel,
 }
 
 impl SVCParameters {
-    /// Define the value of epsilon to use in the epsilon-SVR model.
-    pub fn with_epoch(mut self, epoch: usize) -> Self {
+    /// Define the number of epochs to use in the epsilon-SVC model.
+    #[must_use]
+    pub const fn with_epoch(mut self, epoch: usize) -> Self {
         self.epoch = epoch;
         self
     }
 
-    /// Define the regulation penalty to use with the SVR Model
-    pub fn with_c(mut self, c: f32) -> Self {
+    /// Define the regularization penalty to use with the SVC model
+    #[must_use]
+    pub const fn with_c(mut self, c: f32) -> Self {
         self.c = c;
         self
     }
 
-    /// Define the convergence tolerance to use with the SVR model
-    pub fn with_tol(mut self, tol: f32) -> Self {
+    /// Define the convergence tolerance to use with the SVC model
+    #[must_use]
+    pub const fn with_tol(mut self, tol: f32) -> Self {
         self.tol = tol;
         self
     }
 
-    /// Define which kernel to use with the SVR model
-    pub fn with_kernel(mut self, kernel: Kernel) -> Self {
+    /// Define which kernel to use with the SVC model
+    #[must_use]
+    pub const fn with_kernel(mut self, kernel: Kernel) -> Self {
         self.kernel = kernel;
         self
     }
diff --git a/src/settings/svr_parameters.rs b/src/settings/svr_parameters.rs
index f2bb76b..2bc424d 100644
--- a/src/settings/svr_parameters.rs
+++ b/src/settings/svr_parameters.rs
@@ -1,35 +1,45 @@
+//! Support Vector Regression parameters
+
 pub use crate::utils::Kernel;
 
 /// Parameters for support vector regression
 #[derive(serde::Serialize, serde::Deserialize)]
 pub struct SVRParameters {
+    /// Epsilon in the epsilon-SVR model.
     pub(crate) eps: f32,
+    /// Regularization parameter.
     pub(crate) c: f32,
+    /// Tolerance for stopping criterion.
     pub(crate) tol: f32,
+    /// Kernel to use for the SVR model
     pub(crate) kernel: Kernel,
 }
 
 impl SVRParameters {
     /// Define the value of epsilon to use in the epsilon-SVR model.
-    pub fn with_eps(mut self, eps: f32) -> Self {
+    #[must_use]
+    pub const fn with_eps(mut self, eps: f32) -> Self {
         self.eps = eps;
         self
     }
 
     /// Define the regulation penalty to use with the SVR Model
-    pub fn with_c(mut self, c: f32) -> Self {
+    #[must_use]
+    pub const fn with_c(mut self, c: f32) -> Self {
         self.c = c;
         self
     }
 
     /// Define the convergence tolerance to use with the SVR model
-    pub fn with_tol(mut self, tol: f32) -> Self {
+    #[must_use]
+    pub const fn with_tol(mut self, tol: f32) -> Self {
         self.tol = tol;
         self
     }
 
     /// Define which kernel to use with the SVR model
-    pub fn with_kernel(mut self, kernel: Kernel) -> Self {
+    #[must_use]
+    pub const fn with_kernel(mut self, kernel: Kernel) -> Self {
         self.kernel = kernel;
         self
     }
diff --git a/src/utils.rs b/src/utils.rs
index 91ea5d9..b02cbc7 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -1,27 +1,28 @@
+//! Utility functions for the crate.
+
 use smartcore::{algorithm::neighbour::KNNAlgorithmName, neighbors::KNNWeightFunction};
 use std::fmt::{Debug, Display, Formatter};
 
-pub(crate) fn print_option<T: Display>(x: Option<T>) -> String {
-    match x {
-        None => "None".to_string(),
-        Some(y) => format!("{}", y),
-    }
+/// Convert an Option to a String for printing in display mode.
+pub fn print_option<T: Display>(x: Option<T>) -> String {
+    x.map_or_else(|| "None".to_string(), |y| format!("{y}"))
 }
 
-pub(crate) fn debug_option<T: Debug>(x: Option<T>) -> String {
-    match x {
-        None => "None".to_string(),
-        Some(y) => format!("{:#?}", y),
-    }
+
+/// Convert an Option to a String for printing in debug mode.
+pub fn debug_option<T: Debug>(x: Option<T>) -> String {
+    x.map_or_else(|| "None".to_string(), |y| format!("{y:#?}"))
 }
 
-pub(crate) fn print_knn_weight_function(f: &KNNWeightFunction) -> String {
+/// Get the name for a knn weight function.
+pub fn print_knn_weight_function(f: &KNNWeightFunction) -> String {
     match f {
         KNNWeightFunction::Uniform => "Uniform".to_string(),
         KNNWeightFunction::Distance => "Distance".to_string(),
     }
 }
 
-pub(crate) fn print_knn_search_algorithm(a: &KNNAlgorithmName) -> String {
+/// Get the name for a knn search algorithm.
+pub fn print_knn_search_algorithm(a: &KNNAlgorithmName) -> String {
     match a {
         KNNAlgorithmName::LinearSearch => "Linear Search".to_string(),
         KNNAlgorithmName::CoverTree => "Cover Tree".to_string(),
@@ -47,15 +48,14 @@ pub enum Kernel {
 impl Display for Kernel {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
-            Kernel::Linear => write!(f, "Linear"),
-            Kernel::Polynomial(degree, gamma, coef) => write!(
+            Self::Linear => write!(f, "Linear"),
+            Self::Polynomial(degree, gamma, coef) => write!(
                 f,
-                "Polynomial\n degree = {}\n gamma = {}\n coef = {}",
-                degree, gamma, coef
+                "Polynomial\n degree = {degree}\n gamma = {gamma}\n coef = {coef}"
             ),
-            Kernel::RBF(gamma) => write!(f, "RBF\n gamma = {}", gamma),
-            Kernel::Sigmoid(gamma, coef) => {
-                write!(f, "Sigmoid\n gamma = {}\n coef = {}", gamma, coef)
+            Self::RBF(gamma) => write!(f, "RBF\n gamma = {gamma}"),
+            Self::Sigmoid(gamma, coef) => {
+                write!(f, "Sigmoid\n gamma = {gamma}\n coef = {coef}")
             }
         }
     }
@@ -83,74 +83,73 @@ pub enum Distance {
 impl Display for Distance {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
-            Distance::Euclidean => write!(f, "Euclidean"),
-            Distance::Manhattan => write!(f, "Manhattan"),
-            Distance::Minkowski(n) => write!(f, "Minkowski\n p = {}", n),
-            Distance::Mahalanobis => write!(f, "Mahalanobis"),
-            Distance::Hamming => write!(f, "Hamming"),
+            Self::Euclidean => write!(f, "Euclidean"),
+            Self::Manhattan => write!(f, "Manhattan"),
+            Self::Minkowski(n) => write!(f, "Minkowski\n p = {n}"),
+            Self::Mahalanobis => write!(f, "Mahalanobis"),
+            Self::Hamming => write!(f, "Hamming"),
         }
     }
 }
 
 /// Function to do element-wise multiplication fo two vectors
-pub fn elementwise_multiply(v1: &Vec<f32>, v2: &Vec<f32>) -> Vec<f32> {
+pub fn elementwise_multiply(v1: &[f32], v2: &[f32]) -> Vec<f32> {
     v1.iter().zip(v2).map(|(&i1, &i2)| i1 * i2).collect()
 }
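The switch to slice parameters here is worth a note: `&[f32]` accepts vectors, arrays, and slices alike, while `&Vec<f32>` forced callers to hold an owned `Vec`. A standalone sketch, with the body copied from the diff:

```rust
// Same body as `elementwise_multiply` in the diff above; shown standalone
// so the calling conventions are visible.
fn elementwise_multiply(v1: &[f32], v2: &[f32]) -> Vec<f32> {
    v1.iter().zip(v2).map(|(&i1, &i2)| i1 * i2).collect()
}

fn main() {
    let owned = vec![1.0_f32, 2.0, 3.0];
    let array = [4.0_f32, 5.0, 6.0];
    // A `&Vec<f32>` auto-derefs to `&[f32]`, and arrays borrow as slices too.
    assert_eq!(elementwise_multiply(&owned, &array), vec![4.0, 10.0, 18.0]);
}
```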

 #[cfg(any(feature = "csv"))]
-use polars::prelude::{
-    BooleanChunked, BooleanChunkedBuilder, CsvReader, DataFrame, DataType, NamedFrom, PolarsError,
-    SerReader, Series,
-};
+use polars::prelude::{CsvReader, DataFrame, PolarsError, SerReader};
 
 #[cfg(any(feature = "csv"))]
-pub(crate) fn validate_and_read<P>(file_path: P) -> DataFrame
+/// Read and validate a csv file or URL into a polars `DataFrame`.
+pub fn validate_and_read<P>(file_path: P) -> DataFrame
 where
     P: AsRef<std::path::Path>,
 {
     let file_path_as_str = file_path.as_ref().to_str().unwrap();
 
-    match CsvReader::from_path(file_path_as_str) {
-        Ok(csv) => csv
-            .infer_schema(Some(10))
-            .has_header(
-                csv_sniffer::Sniffer::new()
-                    .sniff_path(file_path_as_str.clone())
-                    .expect("Cannot sniff file")
-                    .dialect
-                    .header
-                    .has_header_row,
-            )
-            .finish()
-            .expect("Cannot read file as CSV")
-            .drop_nulls(None)
-            .expect("Cannot remove null values")
-            .convert_to_float()
-            .expect("Cannot convert types"),
-        Err(_) => {
-            if let Ok(_) = url::Url::parse(file_path_as_str) {
+    CsvReader::from_path(file_path_as_str).map_or_else(
+        |_| {
+            if url::Url::parse(file_path_as_str).is_ok() {
                 let file_contents = minreq::get(file_path_as_str)
                     .send()
                     .expect("Could not open URL");
                 let temp = temp_file::with_contents(file_contents.as_bytes());
-
                 validate_and_read(temp.path().to_str().unwrap())
             } else {
-                panic!(
-                    "The string {} is not a valid URL or file path.",
-                    file_path_as_str
-                )
+                panic!("The string {file_path_as_str} is not a valid URL or file path.")
             }
-        }
-    }
+        },
+        |csv| {
+            csv.infer_schema(Some(10))
+                .has_header(
+                    csv_sniffer::Sniffer::new()
+                        .sniff_path(file_path_as_str)
+                        .expect("Cannot sniff file")
+                        .dialect
+                        .header
+                        .has_header_row,
+                )
+                .finish()
+                .expect("Cannot read file as CSV")
+                .drop_nulls(None)
+                .expect("Cannot remove null values")
+                .convert_to_float()
+                .expect("Cannot convert types")
+        },
+    )
 }
+
+/// Trait to convert to a polars `DataFrame`.
 #[cfg(any(feature = "csv"))]
 trait Cleanup {
+    /// Convert to a polars `DataFrame` with all columns of type float.
     fn convert_to_float(self) -> Result<DataFrame, PolarsError>;
 }
 
 #[cfg(any(feature = "csv"))]
 impl Cleanup for DataFrame {
+    #[allow(unused_mut)]
     fn convert_to_float(mut self) -> Result<DataFrame, PolarsError> {
         // Work in progress
         // for field in self.schema().fields() {
diff --git a/tests/classification.rs b/tests/classification.rs
index 7c89073..4af0bcb 100644
--- a/tests/classification.rs
+++ b/tests/classification.rs
@@ -17,7 +17,7 @@ mod classification_tests {
         classifier.train();
 
         // Try to predict something
-        classifier.predict(vec![vec![5.0 as f32; 30]; 10]);
+        classifier.predict(vec![vec![5.0_f32; 30]; 10]);
         classifier.predict("data/breast_cancer_without_target.csv");
         #[cfg(feature = "nd")]
         classifier.predict(ndarray::Array2::from_shape_vec((10, 30), vec![5.0; 300]).unwrap());
@@ -58,6 +58,6 @@ mod classification_tests {
         classifier.train();
 
         // Try to predict something
-        classifier.predict(vec![vec![5.0 as f32; 30]; 10]);
+        classifier.predict(vec![vec![5.0_f32; 30]; 10]);
     }
 }
diff --git a/tests/new_from_dataset.rs b/tests/new_from_dataset.rs
index 0d7ca95..75b267e 100644
--- a/tests/new_from_dataset.rs
+++ b/tests/new_from_dataset.rs
@@ -16,7 +16,7 @@
         classifier.train();
 
         // Try to predict something from a vector
-        classifier.predict(vec![vec![5.0 as f32; 30]; 10]);
+        classifier.predict(vec![vec![5.0_f32; 30]; 10]);
 
         // Try to predict something from ndarray
         #[cfg(feature = "nd")]
@@ -37,7 +37,7 @@
         regressor.train();
 
         // Try to predict something from a vector
-        regressor.predict(vec![vec![5.0 as f32; 10]; 10]);
+        regressor.predict(vec![vec![5.0_f32; 10]; 10]);
 
         // Try to predict something from ndarray
         #[cfg(feature = "nd")]
diff --git a/tests/regression.rs b/tests/regression.rs
index 127efc4..a395e1e 100644
--- a/tests/regression.rs
+++ b/tests/regression.rs
@@ -17,7 +17,7 @@ mod regression_tests {
         regressor.train();
 
         // Try to predict something
-        regressor.predict(vec![vec![5.0 as f32; 10]; 10]);
+        regressor.predict(vec![vec![5.0_f32; 10]; 10]);
         regressor.predict("data/diabetes_without_target.csv");
         #[cfg(feature = "nd")]
         regressor.predict(ndarray::Array2::from_shape_vec((10, 10), vec![5.0; 100]).unwrap());
@@ -38,7 +38,7 @@
         regressor.train();
 
         // Try to predict something
-        regressor.predict(vec![vec![5.0 as f32; 8]; 8]);
+        regressor.predict(vec![vec![5.0_f32; 8]; 8]);
     }
 
     #[test]
@@ -76,6 +76,6 @@
         regressor.train();
 
         // Try to predict something
-        regressor.predict(vec![vec![5.0 as f32; 10]; 10]);
+        regressor.predict(vec![vec![5.0_f32; 10]; 10]);
     }
 }
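Taken together, the updated regression tests exercise the crate's whole flow end to end. A condensed sketch of that flow (model construction mirrors the crate's documented `SupervisedModel::new`; the string-path overload of `predict` assumes the `csv` feature and the test data file referenced above):

```rust
fn main() {
    // Compare regressors on a built-in smartcore dataset.
    let dataset = smartcore::dataset::diabetes::load_dataset();
    let settings = automl::Settings::default_regression();
    let mut regressor = automl::SupervisedModel::new(dataset, settings);
    regressor.train();

    // Predict from an in-memory matrix...
    regressor.predict(vec![vec![5.0_f32; 10]; 10]);
    // ...or straight from a CSV on disk (`csv` feature).
    regressor.predict("data/diabetes_without_target.csv");
}
```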