Skip to content

Commit 05b07aa

Browse files
committed
moved test example to mod
1 parent ce7fc55 commit 05b07aa

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

src/dataset/impl_dataset.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -577,19 +577,6 @@ where
577577
/// ### Returns
578578
///
579579
/// A new shuffled version of the current Dataset
580-
/// # Example
581-
/// ```
582-
/// use rand::thread_rng;
583-
/// use ndarray::s;
584-
/// let dataset = linfa_datasets::iris();
585-
/// println!("First 5 rows {:?}", dataset.records.slice(s![0..5,..]));
586-
/// println!("Feature names {:?}", dataset.feature_names());
587-
/// println!("Target names {:?}", dataset.target_names());
588-
/// let mut rng = thread_rng();
589-
/// let shuffled = dataset.shuffle(&mut rng);
590-
/// println!("First 5 rows after shuffling {:?}", shuffled.records.slice(s![0..5,..]));
591-
/// println!("Feature names after shuffling {:?}", shuffled.feature_names());
592-
/// println!("Target names after shuffling {:?}", shuffled.target_names());
593580
/// ```
594581
pub fn shuffle<R: Rng>(&self, rng: &mut R) -> DatasetBase<Array2<F>, T::Owned> {
595582
let mut indices = (0..self.nsamples()).collect::<Vec<_>>();

src/dataset/mod.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ pub trait Labels {
353353
mod tests {
354354
use super::*;
355355
use crate::error::Error;
356-
use approx::assert_abs_diff_eq;
356+
use approx::{assert_abs_diff_eq, assert_abs_diff_ne};
357357
use linfa_datasets::generate::make_dataset;
358358
use ndarray::{array, Array1, Array2, Axis};
359359
use rand::{rngs::SmallRng, SeedableRng};
@@ -1050,4 +1050,24 @@ mod tests {
10501050
let prob = -0.5;
10511051
assert_abs_diff_eq!(Pr::new_unchecked(prob).0, prob);
10521052
}
1053+
1054+
#[test]
1055+
fn test_dataset_shuffle() {
1056+
let mut rng = SmallRng::seed_from_u64(42);
1057+
let f_names = vec!["f1", "f2", "f3"];
1058+
let t_names = vec!["t1"];
1059+
let dataset = Dataset::new(
1060+
array![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
1061+
array![0., 1., 3.],
1062+
)
1063+
.with_feature_names(f_names.clone())
1064+
.with_target_names(t_names.clone());
1065+
1066+
let shuffled = dataset.shuffle(&mut rng);
1067+
1068+
assert_abs_diff_ne!(dataset.records(), shuffled.records());
1069+
assert_abs_diff_ne!(dataset.targets(), shuffled.targets());
1070+
assert_eq!(f_names, shuffled.feature_names());
1071+
assert_eq!(t_names, shuffled.target_names());
1072+
}
10531073
}

0 commit comments

Comments
 (0)