
Commit 9744625

relfoojo12 and Femi authored
Add target names to Dataset (#373)
* add target_names
* target_naming w/ test
* add target name
* use ntargets in target_names
* move target_names func to where traits bounds are satisfied
* rustc fmt
* address review
* address target_names review comment
* Adress review comments
* Format
* Compute coverage even when PR is draft
* Add panic tests

---------

Co-authored-by: Femi <[email protected]>
1 parent a30e5f1 commit 9744625

File tree

11 files changed (+128, -32 lines)

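Before the per-file diffs, a minimal usage sketch of the API this commit introduces. It is not taken from the diff: the import paths (linfa::dataset::DatasetBase, ndarray) and the example names ("height", "price", ...) are assumptions, while the method names and panic behaviour come from the changes to src/dataset/impl_dataset.rs below.

use linfa::dataset::DatasetBase;
use ndarray::array;

fn main() {
    // two samples, two features, two target columns
    let records = array![[1.0, 2.0], [3.0, 4.0]];
    let targets = array![[10.0, 0.5], [20.0, 0.7]];

    // Both setters panic if a non-empty name list does not match the number
    // of features/targets (see the `assert!` calls in impl_dataset.rs below).
    let ds = DatasetBase::new(records, targets)
        .with_feature_names(vec!["height", "width"])
        .with_target_names(vec!["price", "score"]);

    // `target_names()` returns a borrowed slice, mirroring the new `feature_names()`.
    assert_eq!(ds.target_names(), vec!["price", "score"]);
}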

.github/workflows/codequality.yml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ jobs:
     needs: codequality
     name: coverage
     runs-on: ubuntu-latest
-    if: github.event.pull_request.draft == false && (github.event_name == 'pull_request' || github.ref == 'refs/heads/master')
+    if: github.event_name == 'pull_request' || github.ref == 'refs/heads/master'

     steps:
       - name: Checkout sources

algorithms/linfa-preprocessing/src/linear_scaling.rs

Lines changed: 4 additions & 2 deletions
@@ -307,12 +307,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
     /// Substitutes the records of the dataset with their scaled version.
     /// Panics if the shape of the records is not compatible with the shape of the dataset used for fitting.
     fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
-        let feature_names = x.feature_names();
+        let feature_names = x.feature_names().to_vec();
+        let target_names = x.target_names().to_vec();
         let (records, targets, weights) = (x.records, x.targets, x.weights);
         let records = self.transform(records.to_owned());
         DatasetBase::new(records, targets)
             .with_weights(weights)
             .with_feature_names(feature_names)
+            .with_target_names(target_names)
     }
 }

@@ -575,7 +577,7 @@ mod tests {
     #[test]
     fn test_retain_feature_names() {
         let dataset = linfa_datasets::diabetes();
-        let original_feature_names = dataset.feature_names();
+        let original_feature_names = dataset.feature_names().to_vec();
         let transformed = LinearScaler::standard()
             .fit(&dataset)
             .unwrap()
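Not part of the commit: a sketch of the guarantee this plumbing adds, which applies identically to the NormScaler and FittedWhitener transformers in the next two files. It assumes linfa-datasets is built with the linnerud feature, that the scaler is reachable at linfa_preprocessing::linear_scaling::LinearScaler, and that its Fit impl accepts multi-target datasets.

use linfa::traits::{Fit, Transformer};
use linfa_preprocessing::linear_scaling::LinearScaler;

fn main() {
    // linnerud() carries target names after this commit (see datasets/src/dataset.rs below)
    let dataset = linfa_datasets::linnerud();
    let original_target_names = dataset.target_names().to_vec();

    let transformed = LinearScaler::standard()
        .fit(&dataset)
        .unwrap()
        .transform(dataset);

    // target names now survive the transform, exactly like feature names
    assert_eq!(original_target_names, transformed.target_names());
}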

algorithms/linfa-preprocessing/src/norm_scaling.rs

Lines changed: 4 additions & 2 deletions
@@ -94,12 +94,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
 {
     /// Substitutes the records of the dataset with their scaled versions with unit norm.
     fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
-        let feature_names = x.feature_names();
+        let feature_names = x.feature_names().to_vec();
+        let target_names = x.target_names().to_vec();
         let (records, targets, weights) = (x.records, x.targets, x.weights);
         let records = self.transform(records.to_owned());
         DatasetBase::new(records, targets)
             .with_weights(weights)
             .with_feature_names(feature_names)
+            .with_target_names(target_names)
     }
 }

@@ -160,7 +162,7 @@ mod tests {
     #[test]
     fn test_retain_feature_names() {
         let dataset = linfa_datasets::diabetes();
-        let original_feature_names = dataset.feature_names();
+        let original_feature_names = dataset.feature_names().to_vec();
         let transformed = NormScaler::l2().transform(dataset);
         assert_eq!(original_feature_names, transformed.feature_names())
     }

algorithms/linfa-preprocessing/src/whitening.rs

Lines changed: 4 additions & 2 deletions
@@ -209,12 +209,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
     for FittedWhitener<F>
 {
     fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
-        let feature_names = x.feature_names();
+        let feature_names = x.feature_names().to_vec();
+        let target_names = x.target_names().to_vec();
         let (records, targets, weights) = (x.records, x.targets, x.weights);
         let records = self.transform(records.to_owned());
         DatasetBase::new(records, targets)
             .with_weights(weights)
             .with_feature_names(feature_names)
+            .with_target_names(target_names)
     }
 }

@@ -334,7 +336,7 @@ mod tests {
     #[test]
     fn test_retain_feature_names() {
         let dataset = linfa_datasets::diabetes();
-        let original_feature_names = dataset.feature_names();
+        let original_feature_names = dataset.feature_names().to_vec();
         let transformed = Whitener::cholesky()
             .fit(&dataset)
             .unwrap()

algorithms/linfa-trees/src/decision_trees/algorithm.rs

Lines changed: 7 additions & 1 deletion
@@ -523,7 +523,13 @@ where
     /// a matrix of features `x` and an array of labels `y`.
     fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
         let x = dataset.records();
-        let feature_names = dataset.feature_names();
+        let feature_names = if dataset.feature_names().is_empty() {
+            (0..x.nfeatures())
+                .map(|idx| format!("feature-{idx}"))
+                .collect()
+        } else {
+            dataset.feature_names().to_vec()
+        };
         let all_idxs = RowMask::all(x.nrows());
         let sorted_indices: Vec<_> = (0..(x.ncols()))
             .map(|feature_idx| {
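A hypothetical standalone helper (not in the diff) that restates the fallback above: because feature_names() now returns whatever was set, possibly an empty slice, rather than synthesizing defaults, the tree builder generates the feature-{idx} names itself when the dataset carries none.

fn feature_names_or_default(names: &[String], nfeatures: usize) -> Vec<String> {
    if names.is_empty() {
        // same default scheme the decision tree now applies during `fit`
        (0..nfeatures).map(|idx| format!("feature-{idx}")).collect()
    } else {
        names.to_vec()
    }
}

fn main() {
    assert_eq!(
        feature_names_or_default(&[], 2),
        vec!["feature-0", "feature-1"]
    );
}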

datasets/src/dataset.rs

Lines changed: 8 additions & 1 deletion
@@ -131,8 +131,11 @@ pub fn linnerud() -> Dataset<f64, f64> {
     let output_array = array_from_gz_csv(&output_data[..], true, b',').unwrap();

     let feature_names = vec!["Chins", "Situps", "Jumps"];
+    let target_names = vec!["Weight", "Waist", "Pulse"];

-    Dataset::new(input_array, output_array).with_feature_names(feature_names)
+    Dataset::new(input_array, output_array)
+        .with_feature_names(feature_names)
+        .with_target_names(target_names)
 }

 #[cfg(test)]
@@ -261,6 +264,10 @@ mod tests {
         let feature_names = vec!["Chins", "Situps", "Jumps"];
         assert_eq!(ds.feature_names(), feature_names);

+        // check for target names
+        let target_names = vec!["Weight", "Waist", "Pulse"];
+        assert_eq!(ds.target_names(), target_names);
+
         // get the mean per target: Weight, Waist, Pulse
         let mean_targets = ds.targets().mean_axis(Axis(0)).unwrap();
         assert_abs_diff_eq!(mean_targets, array![178.6, 35.4, 56.1]);
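As a small illustration of what the new names enable (assumed usage, not from the commit; requires the linnerud feature of linfa-datasets), the per-target means checked in the test above can now be labelled directly:

use ndarray::Axis;

fn main() {
    let ds = linfa_datasets::linnerud();
    let means = ds.targets().mean_axis(Axis(0)).unwrap();
    for (name, mean) in ds.target_names().iter().zip(means.iter()) {
        // prints Weight: 178.6, Waist: 35.4, Pulse: 56.1 (one per line)
        println!("{name}: {mean:.1}");
    }
}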

src/correlation.rs

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ impl<F: Float> PearsonCorrelation<F> {
         PearsonCorrelation {
             pearson_coeffs,
             p_values,
-            feature_names: dataset.feature_names(),
+            feature_names: dataset.feature_names().to_vec(),
         }
     }

src/dataset/impl_dataset.rs

Lines changed: 51 additions & 19 deletions
@@ -30,6 +30,7 @@ impl<R: Records, S> DatasetBase<R, S> {
             targets,
             weights: Array1::zeros(0),
             feature_names: Vec::new(),
+            target_names: Vec::new(),
         }
     }

@@ -60,14 +61,8 @@
     /// A feature name gives a human-readable string describing the purpose of a single feature.
     /// This allow the reader to understand its purpose while analysing results, for example
     /// correlation analysis or feature importance.
-    pub fn feature_names(&self) -> Vec<String> {
-        if !self.feature_names.is_empty() {
-            self.feature_names.clone()
-        } else {
-            (0..self.records.nfeatures())
-                .map(|idx| format!("feature-{idx}"))
-                .collect()
-        }
+    pub fn feature_names(&self) -> &[String] {
+        &self.feature_names
     }

     /// Return records of a dataset
@@ -81,13 +76,14 @@
     /// Updates the records of a dataset
     ///
     /// This function overwrites the records in a dataset. It also invalidates the weights and
-    /// feature names.
+    /// feature/target names.
     pub fn with_records<T: Records>(self, records: T) -> DatasetBase<T, S> {
         DatasetBase {
             records,
             targets: self.targets,
             weights: Array1::zeros(0),
             feature_names: Vec::new(),
+            target_names: Vec::new(),
         }
     }

@@ -100,6 +96,7 @@
             targets,
             weights: self.weights,
             feature_names: self.feature_names,
+            target_names: self.target_names,
         }
     }

@@ -111,11 +108,14 @@
     }

     /// Updates the feature names of a dataset
+    ///
+    /// **Panics** when given names not empty and length does not equal to the number of features
     pub fn with_feature_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
-        let feature_names = names.into_iter().map(|x| x.into()).collect();
-
-        self.feature_names = feature_names;
-
+        assert!(
+            names.is_empty() || names.len() == self.nfeatures(),
+            "Wrong number of feature names"
+        );
+        self.feature_names = names.into_iter().map(|x| x.into()).collect();
         self
     }
 }
@@ -131,6 +131,18 @@ impl<X, Y> Dataset<X, Y> {
 }

 impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
+    /// Updates the target names of a dataset
+    ///
+    /// **Panics** when given names not empty and length does not equal to the number of targets
+    pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, T> {
+        assert!(
+            names.is_empty() || names.len() == self.ntargets(),
+            "Wrong number of target names"
+        );
+        self.target_names = names.into_iter().map(|x| x.into()).collect();
+        self
+    }
+
     /// Map targets with a function `f`
     ///
     /// # Example
@@ -153,6 +165,7 @@ impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
             targets,
             weights,
             feature_names,
+            target_names,
             ..
         } = self;

@@ -163,9 +176,17 @@ impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
             targets: targets.map(fnc),
             weights,
             feature_names,
+            target_names,
         }
     }

+    /// Returns target names
+    ///
+    /// A target name gives a human-readable string describing the purpose of a single target.
+    pub fn target_names(&self) -> &[String] {
+        &self.target_names
+    }
+
     /// Return the number of targets in the dataset
     ///
     /// # Example
@@ -217,6 +238,7 @@ impl<'a, F: 'a, L: 'a, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
 where
     D: Data<Elem = F>,
     T: AsTargets<Elem = L> + FromTargetArray<'a>,
+    T::View: AsTargets<Elem = L>,
 {
     /// Creates a view of a dataset
     pub fn view(&'a self) -> DatasetBase<ArrayView2<'a, F>, T::View> {
@@ -226,6 +248,7 @@ where
         DatasetBase::new(records, targets)
             .with_feature_names(self.feature_names.clone())
             .with_weights(self.weights.clone())
+            .with_target_names(self.target_names.clone())
     }

     /// Iterate over features
@@ -268,6 +291,7 @@ impl<L, R: Records, T: AsTargetsMut<Elem = L>> AsTargetsMut for DatasetBase<R, T
 impl<'a, L: 'a, F, T> DatasetBase<ArrayView2<'a, F>, T>
 where
     T: AsTargets<Elem = L> + FromTargetArray<'a>,
+    T::View: AsTargets<Elem = L>,
 {
     /// Split dataset into two disjoint chunks
     ///
@@ -299,11 +323,13 @@ where
         };
         let dataset1 = DatasetBase::new(records_first, targets_first)
             .with_weights(first_weights)
-            .with_feature_names(self.feature_names.clone());
+            .with_feature_names(self.feature_names.clone())
+            .with_target_names(self.target_names.clone());

         let dataset2 = DatasetBase::new(records_second, targets_second)
             .with_weights(second_weights)
-            .with_feature_names(self.feature_names.clone());
+            .with_feature_names(self.feature_names.clone())
+            .with_target_names(self.target_names.clone());

         (dataset1, dataset2)
     }
@@ -349,7 +375,8 @@ where
                     label,
                     DatasetBase::new(self.records().view(), targets)
                         .with_feature_names(self.feature_names.clone())
-                        .with_weights(self.weights.clone()),
+                        .with_weights(self.weights.clone())
+                        .with_target_names(self.target_names.clone()),
                 )
             })
             .collect())
@@ -405,6 +432,7 @@ impl<F, D: Data<Elem = F>, I: Dimension> From<ArrayBase<D, I>>
             targets: empty_targets,
             weights: Array1::zeros(0),
             feature_names: Vec::new(),
+            target_names: Vec::new(),
         }
     }
 }
@@ -421,6 +449,7 @@ where
             targets: rec_tar.1,
             weights: Array1::zeros(0),
             feature_names: Vec::new(),
+            target_names: Vec::new(),
         }
     }
 }
@@ -957,7 +986,8 @@ impl<F, E, I: TargetDim> Dataset<F, E, I> {
         let n1 = (self.nsamples() as f32 * ratio).ceil() as usize;
        let n2 = self.nsamples() - n1;

-        let feature_names = self.feature_names();
+        let feature_names = self.feature_names().to_vec();
+        let target_names = self.target_names().to_vec();

         // split records into two disjoint arrays
         let mut array_buf = self.records.into_raw_vec();
@@ -990,10 +1020,12 @@
         // create new datasets with attached weights
         let dataset1 = Dataset::new(first, first_targets)
             .with_weights(self.weights)
-            .with_feature_names(feature_names.clone());
+            .with_feature_names(feature_names.clone())
+            .with_target_names(target_names.clone());
         let dataset2 = Dataset::new(second, second_targets)
             .with_weights(second_weights)
-            .with_feature_names(feature_names);
+            .with_feature_names(feature_names.clone())
+            .with_target_names(target_names.clone());

         (dataset1, dataset2)
     }
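The commit message lists "Add panic tests"; the actual tests are not visible in this excerpt, but a sketch of the contract enforced by the new assert! in with_target_names could look like the following (test name and values are invented):

use linfa::dataset::DatasetBase;
use ndarray::array;

#[test]
#[should_panic(expected = "Wrong number of target names")]
fn mismatched_target_names_panic() {
    let records = array![[1.0, 2.0], [3.0, 4.0]];
    let targets = array![[0.0], [1.0]]; // a single target column

    // one target but two names -> the `assert!` in `with_target_names` fires
    DatasetBase::new(records, targets).with_target_names(vec!["a", "b"]);
}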

src/dataset/impl_targets.rs

Lines changed: 2 additions & 1 deletion
@@ -83,7 +83,7 @@ impl<'a, L: Label + 'a, T> FromTargetArray<'a> for CountedTargets<L, T>
 where
     T: FromTargetArray<'a, Elem = L>,
     T::Owned: Labels<Elem = L>,
-    T::View: Labels<Elem = L>,
+    T::View: Labels<Elem = L> + AsTargets,
 {
     type Owned = CountedTargets<L, T::Owned>;
     type View = CountedTargets<L, T::View>;
@@ -231,6 +231,7 @@ where
             weights: Array1::from(weights),
             targets,
             feature_names: self.feature_names.clone(),
+            target_names: self.target_names.clone(),
         }
     }
 }

src/dataset/iter.rs

Lines changed: 8 additions & 1 deletion
@@ -77,18 +77,24 @@ where
         if self.target_or_feature && self.dataset.nfeatures() <= self.idx {
             return None;
         }
-
         let mut records = self.dataset.records.view();
         let mut targets = self.dataset.targets.as_targets();
         let feature_names;
+        let target_names;
         let weights = self.dataset.weights.clone();

         if !self.target_or_feature {
             // This branch should only run for 2D targets
             targets.collapse_axis(Axis(1), self.idx);
             feature_names = self.dataset.feature_names.clone();
+            if self.dataset.target_names.is_empty() {
+                target_names = Vec::new();
+            } else {
+                target_names = vec![self.dataset.target_names[self.idx].clone()];
+            }
         } else {
             records.collapse_axis(Axis(1), self.idx);
+            target_names = self.dataset.target_names.clone();
             if self.dataset.feature_names.len() == records.len_of(Axis(1)) {
                 feature_names = vec![self.dataset.feature_names[self.idx].clone()];
             } else {
@@ -103,6 +109,7 @@
             targets,
             weights,
             feature_names,
+            target_names,
         };

         Some(dataset_view)
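A hedged sketch of what this iterator change means in practice. It assumes the per-target iterator is exposed as target_iter() (that method is not part of this diff) and that linfa-datasets is built with the linnerud feature:

fn main() {
    let ds = linfa_datasets::linnerud();
    for single in ds.target_iter() {
        // each single-target view now carries exactly the name of the target it selects
        assert_eq!(single.target_names().len(), 1);
    }
}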
