Skip to content

Commit

Permalink
moved test example to mod
Browse files Browse the repository at this point in the history
  • Loading branch information
Plutone11011 committed Feb 10, 2025
1 parent ce7fc55 commit 05b07aa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
13 changes: 0 additions & 13 deletions src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,19 +577,6 @@ where
/// ### Returns
///
/// A new shuffled version of the current Dataset
/// # Example
/// ```
/// use rand::thread_rng;
/// use ndarray::s;
/// let dataset = linfa_datasets::iris();
/// println!("First 5 rows {:?}", dataset.records.slice(s![0..5,..]));
/// println!("Feature names {:?}", dataset.feature_names());
/// println!("Target names {:?}", dataset.target_names());
/// let mut rng = thread_rng();
/// let shuffled = dataset.shuffle(&mut rng);
/// println!("First 5 rows after shuffling {:?}", shuffled.records.slice(s![0..5,..]));
/// println!("Feature names after shuffling {:?}", shuffled.feature_names());
/// println!("Target names after shuffling {:?}", shuffled.target_names());
/// ```
pub fn shuffle<R: Rng>(&self, rng: &mut R) -> DatasetBase<Array2<F>, T::Owned> {
let mut indices = (0..self.nsamples()).collect::<Vec<_>>();
Expand Down
22 changes: 21 additions & 1 deletion src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ pub trait Labels {
mod tests {
use super::*;
use crate::error::Error;
use approx::assert_abs_diff_eq;
use approx::{assert_abs_diff_eq, assert_abs_diff_ne};
use linfa_datasets::generate::make_dataset;
use ndarray::{array, Array1, Array2, Axis};
use rand::{rngs::SmallRng, SeedableRng};
Expand Down Expand Up @@ -1050,4 +1050,24 @@ mod tests {
let prob = -0.5;
assert_abs_diff_eq!(Pr::new_unchecked(prob).0, prob);
}

#[test]
fn test_dataset_shuffle() {
let mut rng = SmallRng::seed_from_u64(42);
let f_names = vec!["f1", "f2", "f3"];
let t_names = vec!["t1"];
let dataset = Dataset::new(
array![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
array![0., 1., 3.],
)
.with_feature_names(f_names.clone())
.with_target_names(t_names.clone());

let shuffled = dataset.shuffle(&mut rng);

assert_abs_diff_ne!(dataset.records(), shuffled.records());
assert_abs_diff_ne!(dataset.targets(), shuffled.targets());
assert_eq!(f_names, shuffled.feature_names());
assert_eq!(t_names, shuffled.target_names());
}
}

0 comments on commit 05b07aa

Please sign in to comment.