Skip to content

Commit 583a41f

Browse files
committed
maintain original feature and target names after shuffling
1 parent d910389 commit 583a41f

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

src/dataset/impl_dataset.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,22 @@ where
577577
/// ### Returns
578578
///
579579
/// A new shuffled version of the current Dataset
580+
/// # Example
581+
/// ```
582+
/// let dataset = linfa_datasets::iris();
583+
584+
/// println!("First 5 rows {:?}", dataset.records.slice(s![0..5,..]));
585+
/// let feature_names = dataset.feature_names();
586+
/// let target_names = dataset.target_names();
587+
588+
/// let mut rng = thread_rng();
589+
/// let shuffled = dataset.shuffle(&mut rng);
590+
591+
/// println!("First 5 rows after shuffling {:?}", shuffled.records.slice(s![0..5,..]));
592+
/// assert_eq!(feature_names, shuffled.feature_names());
593+
/// /// assert_eq!(target_names, shuffled.target_names());
594+
///
595+
/// ```
580596
pub fn shuffle<R: Rng>(&self, rng: &mut R) -> DatasetBase<Array2<F>, T::Owned> {
581597
let mut indices = (0..self.nsamples()).collect::<Vec<_>>();
582598
indices.shuffle(rng);
@@ -585,7 +601,7 @@ where
585601
let targets = self.as_targets().select(Axis(0), &indices);
586602
let targets = T::new_targets(targets);
587603

588-
DatasetBase::new(records, targets)
604+
DatasetBase::new(records, targets).with_feature_names(self.feature_names().to_vec()).with_target_names(self.target_names().to_vec())
589605
}
590606

591607
#[allow(clippy::type_complexity)]

0 commit comments

Comments
 (0)