@@ -577,6 +577,22 @@ where
577
577
/// ### Returns
578
578
///
579
579
/// A new shuffled version of the current Dataset
580
+ /// # Example
581
+ /// ```
582
+ /// let dataset = linfa_datasets::iris();
583
+
584
+ /// println!("First 5 rows {:?}", dataset.records.slice(s![0..5,..]));
585
+ /// let feature_names = dataset.feature_names();
586
+ /// let target_names = dataset.target_names();
587
+
588
+ /// let mut rng = thread_rng();
589
+ /// let shuffled = dataset.shuffle(&mut rng);
590
+
591
+ /// println!("First 5 rows after shuffling {:?}", shuffled.records.slice(s![0..5,..]));
592
+ /// assert_eq!(feature_names, shuffled.feature_names());
593
+ /// /// assert_eq!(target_names, shuffled.target_names());
594
+ ///
595
+ /// ```
580
596
pub fn shuffle < R : Rng > ( & self , rng : & mut R ) -> DatasetBase < Array2 < F > , T :: Owned > {
581
597
let mut indices = ( 0 ..self . nsamples ( ) ) . collect :: < Vec < _ > > ( ) ;
582
598
indices. shuffle ( rng) ;
@@ -585,7 +601,7 @@ where
585
601
let targets = self . as_targets ( ) . select ( Axis ( 0 ) , & indices) ;
586
602
let targets = T :: new_targets ( targets) ;
587
603
588
- DatasetBase :: new ( records, targets)
604
+ DatasetBase :: new ( records, targets) . with_feature_names ( self . feature_names ( ) . to_vec ( ) ) . with_target_names ( self . target_names ( ) . to_vec ( ) )
589
605
}
590
606
591
607
#[ allow( clippy:: type_complexity) ]
0 commit comments