Commit f498f96

Implement realnum::rand (#251)

Co-authored-by: Luis Moreno <[email protected]>
Co-authored-by: Lorenzo <[email protected]>

* Implement rand. Use the new derive #[default]
* Use custom range
* Use range seed
* Bump version
* Add array length checks for

1 parent 7d059c4 · commit f498f96

12 files changed: +118 −44
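
The change that repeats across most of these files is mechanical: since Rust 1.62, `#[derive(Default)]` on an enum plus a `#[default]` attribute on one variant replaces a hand-written `impl Default` block. A minimal sketch of the pattern, using `KNNAlgorithmName` from this commit (the same rewrite appears below for the logistic, ridge, KNN-weight, and split-criterion enums):

```rust
// After: one derive and one attribute replace five lines of impl.
#[derive(Debug, Clone, Default)]
pub enum KNNAlgorithmName {
    LinearSearch,
    #[default]
    CoverTree,
}

fn main() {
    // Behavior is unchanged: `default()` still yields CoverTree.
    assert!(matches!(
        KNNAlgorithmName::default(),
        KNNAlgorithmName::CoverTree
    ));
}
```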

Cargo.toml

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 name = "smartcore"
 description = "Machine Learning in Rust."
 homepage = "https://smartcorelib.org"
-version = "0.3.0"
+version = "0.3.1"
 authors = ["smartcore Developers"]
 edition = "2021"
 license = "Apache-2.0"

src/algorithm/neighbour/mod.rs

Lines changed: 2 additions & 7 deletions

@@ -49,20 +49,15 @@ pub mod linear_search;
 /// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
 /// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Default)]
 pub enum KNNAlgorithmName {
     /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
     LinearSearch,
     /// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
+    #[default]
     CoverTree,
 }
 
-impl Default for KNNAlgorithmName {
-    fn default() -> Self {
-        KNNAlgorithmName::CoverTree
-    }
-}
-
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug)]
 pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {

src/cluster/dbscan.rs

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 //!
 //! Example:
 //!
-//! ```
+//! ```ignore
 //! use smartcore::linalg::basic::matrix::DenseMatrix;
 //! use smartcore::linalg::basic::arrays::Array2;
 //! use smartcore::cluster::dbscan::*;
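
Switching the doctest fence to `ignore` keeps the example visible in the rendered rustdoc but stops `cargo test` from compiling and running it. A hedged illustration of the attribute in isolation (the function and call here are placeholders, not smartcore's API):

```rust
/// Example:
///
/// ```ignore
/// // Shown in rustdoc output, but skipped by `cargo test`,
/// // which avoids running a slow or flaky doctest.
/// let labels = run_dbscan_example();
/// ```
pub fn placeholder() {}
```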

src/ensemble/random_forest_classifier.rs

Lines changed: 29 additions & 1 deletion

@@ -454,8 +454,12 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
         y: &Y,
         parameters: RandomForestClassifierParameters,
     ) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
-        let (_, num_attributes) = x.shape();
+        let (x_nrows, num_attributes) = x.shape();
         let y_ncols = y.shape();
+        if x_nrows != y_ncols {
+            return Err(Failed::fit("Number of rows in X should = len(y)"));
+        }
+
         let mut yi: Vec<usize> = vec![0; y_ncols];
         let classes = y.unique();

@@ -678,6 +682,30 @@ mod tests {
         assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
     }
 
+    #[test]
+    fn test_random_matrix_with_wrong_rownum() {
+        let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
+
+        let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
+
+        let fail = RandomForestClassifier::fit(
+            &x_rand,
+            &y,
+            RandomForestClassifierParameters {
+                criterion: SplitCriterion::Gini,
+                max_depth: Option::None,
+                min_samples_leaf: 1,
+                min_samples_split: 2,
+                n_trees: 100,
+                m: Option::None,
+                keep_samples: false,
+                seed: 87,
+            },
+        );
+
+        assert!(fail.is_err());
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test
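
The new test exercises both halves of this commit at once: `DenseMatrix::rand` (backed by the `realnum::rand` implemented below) builds the input, and the new shape check rejects it when labels are missing. A condensed sketch of the same behavior, assuming `RandomForestClassifierParameters` implements `Default` (the decision-tree test in this commit passes `Default::default()` the same way):

```rust
use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
use smartcore::linalg::basic::matrix::DenseMatrix;

fn main() {
    // 21 random rows but only 20 labels: the new guard returns
    // Err(Failed) up front instead of failing deep inside training.
    let x: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 4);
    let y: Vec<u32> = (0..20u32).map(|i| i % 2).collect();
    assert!(RandomForestClassifier::fit(&x, &y, Default::default()).is_err());
}
```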

src/ensemble/random_forest_regressor.rs

Lines changed: 30 additions & 0 deletions

@@ -399,6 +399,10 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
     ) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
         let (n_rows, num_attributes) = x.shape();
 
+        if n_rows != y.shape() {
+            return Err(Failed::fit("Number of rows in X should = len(y)"));
+        }
+
         let mtry = parameters
             .m
             .unwrap_or((num_attributes as f64).sqrt().floor() as usize);

@@ -595,6 +599,32 @@ mod tests {
         assert!(mean_absolute_error(&y, &y_hat) < 1.0);
     }
 
+    #[test]
+    fn test_random_matrix_with_wrong_rownum() {
+        let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);
+
+        let y = vec![
+            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
+            114.2, 115.7, 116.9,
+        ];
+
+        let fail = RandomForestRegressor::fit(
+            &x_rand,
+            &y,
+            RandomForestRegressorParameters {
+                max_depth: Option::None,
+                min_samples_leaf: 1,
+                min_samples_split: 2,
+                n_trees: 1000,
+                m: Option::None,
+                keep_samples: false,
+                seed: 87,
+            },
+        );
+
+        assert!(fail.is_err());
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test

src/error/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ pub enum FailedError {
     DecompositionFailed,
     /// Can't solve for x
     SolutionFailed,
-    /// Erro in input
+    /// Error in input parameters
    ParametersError,
 }
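
`Failed::fit` (used in every new guard above) wraps a message into the crate's `Failed` error type. A minimal sketch of the same precondition pattern in isolation; `check_shapes` is a hypothetical helper for illustration, not part of smartcore:

```rust
use smartcore::error::Failed;

// Hypothetical helper mirroring the new input checks:
// reject mismatched row counts before doing any work.
fn check_shapes(x_nrows: usize, y_len: usize) -> Result<(), Failed> {
    if x_nrows != y_len {
        return Err(Failed::fit("Number of rows in X should = len(y)"));
    }
    Ok(())
}

fn main() {
    assert!(check_shapes(21, 20).is_err());
    assert!(check_shapes(20, 20).is_ok());
}
```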

src/linear/logistic_regression.rs

Lines changed: 2 additions & 7 deletions

@@ -71,19 +71,14 @@ use crate::optimization::line_search::Backtracking;
 use crate::optimization::FunctionOrder;
 
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq, Default)]
 /// Solver options for Logistic regression. Right now only LBFGS solver is supported.
 pub enum LogisticRegressionSolverName {
     /// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
+    #[default]
     LBFGS,
 }
 
-impl Default for LogisticRegressionSolverName {
-    fn default() -> Self {
-        LogisticRegressionSolverName::LBFGS
-    }
-}
-
 /// Logistic Regression parameters
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]

src/linear/ridge_regression.rs

Lines changed: 2 additions & 7 deletions

@@ -71,21 +71,16 @@ use crate::numbers::basenum::Number;
 use crate::numbers::realnum::RealNumber;
 
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq, Default)]
 /// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
 pub enum RidgeRegressionSolverName {
     /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
+    #[default]
     Cholesky,
     /// SVD decomposition, see [SVD](../../linalg/svd/index.html)
     SVD,
 }
 
-impl Default for RidgeRegressionSolverName {
-    fn default() -> Self {
-        RidgeRegressionSolverName::Cholesky
-    }
-}
-
 /// Ridge Regression parameters
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
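
With the derive in place, `RidgeRegressionSolverName::default()` still resolves to `Cholesky`; callers wanting the more numerically stable SVD path opt in explicitly. A sketch of doing so, assuming the parameter struct exposes a public `solver` field (the field name is an assumption, not shown in this diff):

```rust
use smartcore::linear::ridge_regression::{
    RidgeRegressionParameters, RidgeRegressionSolverName,
};

fn main() {
    // Default: the efficient Cholesky decomposition.
    assert!(matches!(
        RidgeRegressionSolverName::default(),
        RidgeRegressionSolverName::Cholesky
    ));

    // Opting into SVD for numerically harder problems;
    // `solver` is an assumed public field on the parameters.
    let params: RidgeRegressionParameters<f64> = RidgeRegressionParameters {
        solver: RidgeRegressionSolverName::SVD,
        ..Default::default()
    };
    let _ = params;
}
```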

src/neighbors/mod.rs

Lines changed: 2 additions & 7 deletions

@@ -49,20 +49,15 @@ pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
 
 /// Weight function that is used to determine estimated value.
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Default)]
 pub enum KNNWeightFunction {
     /// All k nearest points are weighted equally
+    #[default]
     Uniform,
     /// k nearest points are weighted by the inverse of their distance. Closer neighbors will have a greater influence than neighbors which are further away.
     Distance,
 }
 
-impl Default for KNNWeightFunction {
-    fn default() -> Self {
-        KNNWeightFunction::Uniform
-    }
-}
-
 impl KNNWeightFunction {
     fn calc_weights(&self, distances: Vec<f64>) -> std::vec::Vec<f64> {
         match *self {
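
The body of `calc_weights` is outside this hunk, but the two variants' doc comments pin down the semantics. A standalone sketch of what uniform versus inverse-distance weighting computes (illustration only, not the crate-internal code):

```rust
// Uniform: every neighbor counts the same.
fn uniform_weights(distances: &[f64]) -> Vec<f64> {
    distances.iter().map(|_| 1.0).collect()
}

// Distance: closer neighbors dominate; guard against d == 0.
fn inverse_distance_weights(distances: &[f64]) -> Vec<f64> {
    distances.iter().map(|d| 1.0 / d.max(f64::EPSILON)).collect()
}

fn main() {
    let d = [0.5, 1.0, 2.0];
    assert_eq!(uniform_weights(&d), vec![1.0, 1.0, 1.0]);
    assert_eq!(inverse_distance_weights(&d), vec![2.0, 1.0, 0.5]);
}
```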

src/numbers/realnum.rs

Lines changed: 26 additions & 3 deletions

@@ -2,9 +2,13 @@
 //! Most algorithms in `smartcore` rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, ℝ.
 //! This module defines real number and some useful functions that are used in [Linear Algebra](../../linalg/index.html) module.
 
+use rand::rngs::SmallRng;
+use rand::{Rng, SeedableRng};
+
 use num_traits::Float;
 
 use crate::numbers::basenum::Number;
+use crate::rand_custom::get_rng_impl;
 
 /// Defines real number
 /// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>

@@ -63,8 +67,12 @@ impl RealNumber for f64 {
     }
 
     fn rand() -> f64 {
-        // TODO: to be implemented, see issue smartcore#214
-        1.0
+        let mut small_rng = get_rng_impl(None);
+
+        let mut rngs: Vec<SmallRng> = (0..3)
+            .map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
+            .collect();
+        rngs[0].gen::<f64>()
     }
 
     fn two() -> Self {

@@ -108,7 +116,12 @@ impl RealNumber for f32 {
     }
 
     fn rand() -> f32 {
-        1.0
+        let mut small_rng = get_rng_impl(None);
+
+        let mut rngs: Vec<SmallRng> = (0..3)
+            .map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
+            .collect();
+        rngs[0].gen::<f32>()
     }
 
     fn two() -> Self {

@@ -149,4 +162,14 @@ mod tests {
     fn f64_from_string() {
         assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111)
     }
+
+    #[test]
+    fn f64_rand() {
+        f64::rand();
+    }
+
+    #[test]
+    fn f32_rand() {
+        f32::rand();
+    }
 }
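
The stub that always returned `1.0` is gone: `rand()` now seeds a `SmallRng` through `get_rng_impl` and finishes with `gen::<f64>()`, whose `Standard` distribution yields values in `[0, 1)`. A quick sanity check one could run against the new implementation:

```rust
use smartcore::numbers::realnum::RealNumber;

fn main() {
    // With the old stub this was always exactly 1.0;
    // gen::<f64>() produces a value in the half-open range [0, 1).
    let v = f64::rand();
    assert!((0.0..1.0).contains(&v));
    println!("sampled: {v}");
}
```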

src/tree/decision_tree_classifier.rs

Lines changed: 18 additions & 8 deletions

@@ -137,29 +137,24 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         self.classes.as_ref()
     }
     /// Get depth of tree
-    fn depth(&self) -> u16 {
+    pub fn depth(&self) -> u16 {
         self.depth
     }
 }
 
 /// The function to measure the quality of a split.
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Default)]
 pub enum SplitCriterion {
     /// [Gini index](../decision_tree_classifier/index.html)
+    #[default]
     Gini,
     /// [Entropy](../decision_tree_classifier/index.html)
     Entropy,
     /// [Classification error](../decision_tree_classifier/index.html)
     ClassificationError,
 }
 
-impl Default for SplitCriterion {
-    fn default() -> Self {
-        SplitCriterion::Gini
-    }
-}
-
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
 struct Node {

@@ -543,6 +538,10 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         parameters: DecisionTreeClassifierParameters,
     ) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
+        if x_nrows != y.shape() {
+            return Err(Failed::fit("Size of x should equal size of y"));
+        }
+
         let samples = vec![1; x_nrows];
         DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }

@@ -968,6 +967,17 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_random_matrix_with_wrong_rownum() {
+        let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
+
+        let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
+
+        let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default());
+
+        assert!(fail.is_err());
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test
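
Making `depth()` `pub` lets callers inspect how far a fitted tree grew. A small sketch chaining pieces this diff already exercises (`DenseMatrix::rand` and `fit` with `Default::default()`); with random inputs the exact depth will vary:

```rust
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;

fn main() {
    // Matching shapes this time: 20 rows, 20 labels.
    let x: DenseMatrix<f64> = DenseMatrix::<f64>::rand(20, 5);
    let y: Vec<u32> = (0..20u32).map(|i| i % 2).collect();
    let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
    // Newly public accessor:
    println!("tree depth = {}", tree.depth());
}
```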

src/tree/decision_tree_regressor.rs

Lines changed: 4 additions & 1 deletion

@@ -18,7 +18,6 @@
 //! Example:
 //!
 //! ```
-//! use rand::thread_rng;
 //! use smartcore::linalg::basic::matrix::DenseMatrix;
 //! use smartcore::tree::decision_tree_regressor::*;
 //!

@@ -422,6 +421,10 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         parameters: DecisionTreeRegressorParameters,
     ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
+        if x_nrows != y.shape() {
+            return Err(Failed::fit("Size of x should equal size of y"));
+        }
+
         let samples = vec![1; x_nrows];
         DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }
