Add target names to Dataset (#373)
* add target_names

* target_naming w/ test

* add target name

* use ntargets in target_names

* move target_names func to where trait bounds are satisfied

* rustc fmt

* address review

* address target_names review comment

* Address review comments

* Format

* Compute coverage even when PR is draft

* Add panic tests

---------

Co-authored-by: Femi <[email protected]>
relf and oojo12 authored Feb 3, 2025
1 parent a30e5f1 commit 9744625
Showing 11 changed files with 128 additions and 32 deletions.
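At a glance, the commit adds a `target_names: Vec<String>` field to `DatasetBase`, a `with_target_names` builder, and a `target_names()` accessor, then threads the names through views, splits, and transformers. A minimal usage sketch before the per-file diff (toy values and names, chosen here for illustration):

```rust
use linfa::Dataset;
use ndarray::array;

fn main() {
    // Two samples, two features, two target columns (toy data).
    let records = array![[1.0, 2.0], [3.0, 4.0]];
    let targets = array![[10.0, 0.5], [20.0, 0.7]];

    let ds = Dataset::new(records, targets)
        .with_feature_names(vec!["height", "width"])
        // New in this commit: attach human-readable target names.
        .with_target_names(vec!["price", "score"]);

    // Also new: target_names() returns the stored names as a slice.
    assert_eq!(ds.target_names(), vec!["price", "score"]);
}
```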
2 changes: 1 addition & 1 deletion .github/workflows/codequality.yml
@@ -33,7 +33,7 @@ jobs:
needs: codequality
name: coverage
runs-on: ubuntu-latest
-if: github.event.pull_request.draft == false && (github.event_name == 'pull_request' || github.ref == 'refs/heads/master')
+if: github.event_name == 'pull_request' || github.ref == 'refs/heads/master'

steps:
- name: Checkout sources
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/linear_scaling.rs
@@ -307,12 +307,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
/// Substitutes the records of the dataset with their scaled version.
/// Panics if the shape of the records is not compatible with the shape of the dataset used for fitting.
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
-let feature_names = x.feature_names();
+let feature_names = x.feature_names().to_vec();
+let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
+.with_target_names(target_names)
}
}

@@ -575,7 +577,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
-let original_feature_names = dataset.feature_names();
+let original_feature_names = dataset.feature_names().to_vec();
let transformed = LinearScaler::standard()
.fit(&dataset)
.unwrap()
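The two preprocessing files below (norm_scaling.rs, whitening.rs) get the identical treatment, so one sketch covers all three: target names now survive a fit/transform round trip. The "progression" name is our own label for the diabetes target, not something the dataset ships with:

```rust
use linfa::traits::{Fit, Transformer};
use linfa_preprocessing::linear_scaling::LinearScaler;

fn main() {
    // diabetes() has a single target column; we name it ourselves.
    let ds = linfa_datasets::diabetes().with_target_names(vec!["progression"]);

    let scaled = LinearScaler::standard()
        .fit(&ds)
        .unwrap()
        .transform(ds);

    // Before this commit the scaled dataset came back with no target names.
    assert_eq!(scaled.target_names(), vec!["progression"]);
}
```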
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/norm_scaling.rs
@@ -94,12 +94,14 @@
{
/// Substitutes the records of the dataset with their scaled versions with unit norm.
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
-let feature_names = x.feature_names();
+let feature_names = x.feature_names().to_vec();
+let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
+.with_target_names(target_names)
}
}

@@ -160,7 +162,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
-let original_feature_names = dataset.feature_names();
+let original_feature_names = dataset.feature_names().to_vec();
let transformed = NormScaler::l2().transform(dataset);
assert_eq!(original_feature_names, transformed.feature_names())
}
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/whitening.rs
@@ -209,12 +209,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
for FittedWhitener<F>
{
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
-let feature_names = x.feature_names();
+let feature_names = x.feature_names().to_vec();
+let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
+.with_target_names(target_names)
}
}

@@ -334,7 +336,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
-let original_feature_names = dataset.feature_names();
+let original_feature_names = dataset.feature_names().to_vec();
let transformed = Whitener::cholesky()
.fit(&dataset)
.unwrap()
8 changes: 7 additions & 1 deletion algorithms/linfa-trees/src/decision_trees/algorithm.rs
@@ -523,7 +523,13 @@ where
/// a matrix of features `x` and an array of labels `y`.
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
let x = dataset.records();
-let feature_names = dataset.feature_names();
+let feature_names = if dataset.feature_names().is_empty() {
+    (0..x.nfeatures())
+        .map(|idx| format!("feature-{idx}"))
+        .collect()
+} else {
+    dataset.feature_names().to_vec()
+};
let all_idxs = RowMask::all(x.nrows());
let sorted_indices: Vec<_> = (0..(x.ncols()))
.map(|feature_idx| {
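Since feature_names() no longer synthesizes defaults itself (see the impl_dataset.rs hunks below), the tree fit now builds them on demand. A sketch of the caller-visible behaviour, with toy data:

```rust
use linfa::{traits::Fit, Dataset};
use linfa_trees::DecisionTree;
use ndarray::array;

fn main() {
    let records = array![[1.0, 2.0], [1.5, 2.5], [5.0, 6.0], [5.5, 6.5]];
    let labels = array![0usize, 0, 1, 1];

    // No with_feature_names() call, so feature_names() is empty and fit()
    // falls back to the synthesized "feature-0", "feature-1" names.
    let ds = Dataset::new(records, labels);
    let _tree = DecisionTree::params().fit(&ds).unwrap();
}
```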
9 changes: 8 additions & 1 deletion datasets/src/dataset.rs
@@ -131,8 +131,11 @@ pub fn linnerud() -> Dataset<f64, f64> {
let output_array = array_from_gz_csv(&output_data[..], true, b',').unwrap();

let feature_names = vec!["Chins", "Situps", "Jumps"];
+let target_names = vec!["Weight", "Waist", "Pulse"];

-Dataset::new(input_array, output_array).with_feature_names(feature_names)
+Dataset::new(input_array, output_array)
+    .with_feature_names(feature_names)
+    .with_target_names(target_names)
}

#[cfg(test)]
@@ -261,6 +264,10 @@ mod tests {
let feature_names = vec!["Chins", "Situps", "Jumps"];
assert_eq!(ds.feature_names(), feature_names);

+// check for target names
+let target_names = vec!["Weight", "Waist", "Pulse"];
+assert_eq!(ds.target_names(), target_names);

// get the mean per target: Weight, Waist, Pulse
let mean_targets = ds.targets().mean_axis(Axis(0)).unwrap();
assert_abs_diff_eq!(mean_targets, array![178.6, 35.4, 56.1]);
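With the loader change, the names line up with the target columns; an illustrative consumer (not part of the diff):

```rust
use ndarray::Axis;

fn main() {
    let ds = linfa_datasets::linnerud();

    // Pair each target column with its newly attached name.
    for (name, col) in ds
        .target_names()
        .iter()
        .zip(ds.targets().axis_iter(Axis(1)))
    {
        println!("{name}: mean = {:.1}", col.mean().unwrap());
    }
}
```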
2 changes: 1 addition & 1 deletion src/correlation.rs
@@ -169,7 +169,7 @@ impl<F: Float> PearsonCorrelation<F> {
PearsonCorrelation {
pearson_coeffs,
p_values,
-feature_names: dataset.feature_names(),
+feature_names: dataset.feature_names().to_vec(),
}
}

70 changes: 51 additions & 19 deletions src/dataset/impl_dataset.rs
@@ -30,6 +30,7 @@ impl<R: Records, S> DatasetBase<R, S> {
targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
+target_names: Vec::new(),
}
}

@@ -60,14 +61,8 @@ impl<R: Records, S> DatasetBase<R, S> {
/// A feature name gives a human-readable string describing the purpose of a single feature.
/// This allow the reader to understand its purpose while analysing results, for example
/// correlation analysis or feature importance.
-pub fn feature_names(&self) -> Vec<String> {
-    if !self.feature_names.is_empty() {
-        self.feature_names.clone()
-    } else {
-        (0..self.records.nfeatures())
-            .map(|idx| format!("feature-{idx}"))
-            .collect()
-    }
+pub fn feature_names(&self) -> &[String] {
+    &self.feature_names
}

/// Return records of a dataset
@@ -81,13 +76,14 @@ impl<R: Records, S> DatasetBase<R, S> {
/// Updates the records of a dataset
///
/// This function overwrites the records in a dataset. It also invalidates the weights and
-/// feature names.
+/// feature/target names.
pub fn with_records<T: Records>(self, records: T) -> DatasetBase<T, S> {
DatasetBase {
records,
targets: self.targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
+target_names: Vec::new(),
}
}

@@ -100,6 +96,7 @@
targets,
weights: self.weights,
feature_names: self.feature_names,
+target_names: self.target_names,
}
}

@@ -111,11 +108,14 @@
}

/// Updates the feature names of a dataset
+///
+/// **Panics** when given names not empty and length does not equal to the number of features
pub fn with_feature_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
-let feature_names = names.into_iter().map(|x| x.into()).collect();
-
-self.feature_names = feature_names;
-
+assert!(
+    names.is_empty() || names.len() == self.nfeatures(),
+    "Wrong number of feature names"
+);
+self.feature_names = names.into_iter().map(|x| x.into()).collect();
self
}
}
@@ -131,6 +131,18 @@ impl<X, Y> Dataset<X, Y> {
}

impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
+/// Updates the target names of a dataset
+///
+/// **Panics** when given names not empty and length does not equal to the number of targets
+pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, T> {
+    assert!(
+        names.is_empty() || names.len() == self.ntargets(),
+        "Wrong number of target names"
+    );
+    self.target_names = names.into_iter().map(|x| x.into()).collect();
+    self
+}

/// Map targets with a function `f`
///
/// # Example
@@ -153,6 +165,7 @@ impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
targets,
weights,
feature_names,
+target_names,
..
} = self;

@@ -163,9 +176,17 @@
targets: targets.map(fnc),
weights,
feature_names,
+target_names,
}
}

+/// Returns target names
+///
+/// A target name gives a human-readable string describing the purpose of a single target.
+pub fn target_names(&self) -> &[String] {
+    &self.target_names
+}

/// Return the number of targets in the dataset
///
/// # Example
@@ -217,6 +238,7 @@ impl<'a, F: 'a, L: 'a, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
where
D: Data<Elem = F>,
T: AsTargets<Elem = L> + FromTargetArray<'a>,
+T::View: AsTargets<Elem = L>,
{
/// Creates a view of a dataset
pub fn view(&'a self) -> DatasetBase<ArrayView2<'a, F>, T::View> {
@@ -226,6 +248,7 @@ where
DatasetBase::new(records, targets)
.with_feature_names(self.feature_names.clone())
.with_weights(self.weights.clone())
+.with_target_names(self.target_names.clone())
}

/// Iterate over features
@@ -268,6 +291,7 @@ impl<L, R: Records, T: AsTargetsMut<Elem = L>> AsTargetsMut for DatasetBase<R, T>
impl<'a, L: 'a, F, T> DatasetBase<ArrayView2<'a, F>, T>
where
T: AsTargets<Elem = L> + FromTargetArray<'a>,
+T::View: AsTargets<Elem = L>,
{
/// Split dataset into two disjoint chunks
///
@@ -299,11 +323,13 @@
};
let dataset1 = DatasetBase::new(records_first, targets_first)
.with_weights(first_weights)
-.with_feature_names(self.feature_names.clone());
+.with_feature_names(self.feature_names.clone())
+.with_target_names(self.target_names.clone());

let dataset2 = DatasetBase::new(records_second, targets_second)
.with_weights(second_weights)
-.with_feature_names(self.feature_names.clone());
+.with_feature_names(self.feature_names.clone())
+.with_target_names(self.target_names.clone());

(dataset1, dataset2)
}
@@ -349,7 +375,8 @@
label,
DatasetBase::new(self.records().view(), targets)
.with_feature_names(self.feature_names.clone())
-.with_weights(self.weights.clone()),
+.with_weights(self.weights.clone())
+.with_target_names(self.target_names.clone()),
)
})
.collect())
@@ -405,6 +432,7 @@ impl<F, D: Data<Elem = F>, I: Dimension> From<ArrayBase<D, I>>
targets: empty_targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
+target_names: Vec::new(),
}
}
}
@@ -421,6 +449,7 @@
targets: rec_tar.1,
weights: Array1::zeros(0),
feature_names: Vec::new(),
+target_names: Vec::new(),
}
}
}
@@ -957,7 +986,8 @@ impl<F, E, I: TargetDim> Dataset<F, E, I> {
let n1 = (self.nsamples() as f32 * ratio).ceil() as usize;
let n2 = self.nsamples() - n1;

-let feature_names = self.feature_names();
+let feature_names = self.feature_names().to_vec();
+let target_names = self.target_names().to_vec();

// split records into two disjoint arrays
let mut array_buf = self.records.into_raw_vec();
@@ -990,10 +1020,12 @@
// create new datasets with attached weights
let dataset1 = Dataset::new(first, first_targets)
.with_weights(self.weights)
-.with_feature_names(feature_names.clone());
+.with_feature_names(feature_names.clone())
+.with_target_names(target_names.clone());
let dataset2 = Dataset::new(second, second_targets)
.with_weights(second_weights)
-.with_feature_names(feature_names);
+.with_feature_names(feature_names.clone())
+.with_target_names(target_names.clone());

(dataset1, dataset2)
}
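The new with_target_names mirrors the validation this commit also adds to with_feature_names, and the added panic tests exercise it. A sketch of the contract, using only API shown in the diff:

```rust
use linfa::Dataset;
use ndarray::array;

fn main() {
    // One sample, two target columns.
    let ds = Dataset::new(array![[1.0, 2.0]], array![[0.1, 0.2]]);

    // Matching length: accepted.
    let ds = ds.with_target_names(vec!["t0", "t1"]);

    // Empty list: accepted, and clears the stored names.
    let ds = ds.with_target_names(Vec::<String>::new());
    assert!(ds.target_names().is_empty());

    // A non-empty list of the wrong length would panic with
    // "Wrong number of target names":
    // let ds = ds.with_target_names(vec!["only-one"]);
}
```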
3 changes: 2 additions & 1 deletion src/dataset/impl_targets.rs
@@ -83,7 +83,7 @@ impl<'a, L: Label + 'a, T> FromTargetArray<'a> for CountedTargets<L, T>
where
T: FromTargetArray<'a, Elem = L>,
T::Owned: Labels<Elem = L>,
-T::View: Labels<Elem = L>,
+T::View: Labels<Elem = L> + AsTargets,
{
type Owned = CountedTargets<L, T::Owned>;
type View = CountedTargets<L, T::View>;
@@ -231,6 +231,7 @@
weights: Array1::from(weights),
targets,
feature_names: self.feature_names.clone(),
+target_names: self.target_names.clone(),
}
}
}
9 changes: 8 additions & 1 deletion src/dataset/iter.rs
@@ -77,18 +77,24 @@
if self.target_or_feature && self.dataset.nfeatures() <= self.idx {
return None;
}

let mut records = self.dataset.records.view();
let mut targets = self.dataset.targets.as_targets();
let feature_names;
+let target_names;
let weights = self.dataset.weights.clone();

if !self.target_or_feature {
// This branch should only run for 2D targets
targets.collapse_axis(Axis(1), self.idx);
feature_names = self.dataset.feature_names.clone();
+if self.dataset.target_names.is_empty() {
+    target_names = Vec::new();
+} else {
+    target_names = vec![self.dataset.target_names[self.idx].clone()];
+}
} else {
records.collapse_axis(Axis(1), self.idx);
+target_names = self.dataset.target_names.clone();
if self.dataset.feature_names.len() == records.len_of(Axis(1)) {
feature_names = vec![self.dataset.feature_names[self.idx].clone()];
} else {
@@ -103,6 +109,7 @@
targets,
weights,
feature_names,
+target_names,
};

Some(dataset_view)
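If we read this change right, per-target views produced by the dataset's target iterator now carry just their own name. A sketch, assuming the target_iter() iterator that this file backs:

```rust
fn main() {
    let ds = linfa_datasets::linnerud();

    // Each view selects one target column; after this commit it keeps only
    // the matching name: "Weight", then "Waist", then "Pulse".
    for (view, expected) in ds.target_iter().zip(["Weight", "Waist", "Pulse"]) {
        assert_eq!(view.target_names(), vec![expected]);
    }
}
```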