Commit 31a78fb

add max_features and tokenizer to CountVectorizer
1 parent a30e5f1 commit 31a78fb

7 files changed · +128 −21 lines changed
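
Taken together, the change lets callers cap the learned vocabulary at the most frequent entries and replace the regex-based tokenization with a custom function, while `transform` and `transform_files` become fallible. A minimal usage sketch (the documents, the 1000-entry cap, and the whitespace tokenizer are made up for illustration; import paths assume the crate's usual top-level re-export of `CountVectorizer`):

use linfa_preprocessing::CountVectorizer;
use ndarray::array;

// Custom tokenizers must match the new `fn(&str) -> Vec<&str>` signature.
fn whitespace_tokenizer(doc: &str) -> Vec<&str> {
    doc.split_whitespace().collect()
}

fn main() {
    let texts = array!["one two three", "two three four", "three four"];
    let vectorizer = CountVectorizer::params()
        .tokenizer(Some(whitespace_tokenizer)) // used instead of the regex tokenization
        .max_features(Some(1000))              // keep at most the 1000 most frequent entries
        .fit(&texts)
        .unwrap();
    // `transform` now returns a Result, so it must be unwrapped or propagated.
    let counts = vectorizer.transform(&texts).unwrap();
    println!("non-zero counts: {}", counts.nnz());
}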

algorithms/linfa-preprocessing/Cargo.toml

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ encoding = "0.2"
 sprs = { version = "=0.11.1", default-features = false }
 
 serde_regex = { version = "1.1", optional = true }
+itertools = "0.14.0"
 
 [dependencies.serde_crate]
 package = "serde"

algorithms/linfa-preprocessing/examples/count_vectorization.rs

Lines changed: 2 additions & 0 deletions

@@ -126,6 +126,7 @@ fn main() {
     // Transforming gives a sparse dataset, we make it dense in order to be able to fit the Naive Bayes model
     let training_records = vectorizer
         .transform_files(&training_filenames, ISO_8859_1, Strict)
+        .unwrap()
         .to_dense();
     // Currently linfa only allows real valued features so we have to transform the integer counts to floats
     let training_records = training_records.mapv(|c| c as f32);
@@ -164,6 +165,7 @@ fn main() {
     );
     let test_records = vectorizer
         .transform_files(&test_filenames, ISO_8859_1, Strict)
+        .unwrap()
         .to_dense();
     let test_records = test_records.mapv(|c| c as f32);
     let test_dataset: Dataset<f32, usize, Ix1> = (test_records, test_targets).into();
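
This example (and the tfidf one below) simply unwraps the now-fallible `transform_files`. In application code the error can be propagated instead; a sketch of that alternative, using a hypothetical `vectorize_files` helper and a boxed error for brevity:

use encoding::all::ISO_8859_1;
use encoding::DecoderTrap::Strict;
use linfa_preprocessing::CountVectorizer;
use ndarray::Array2;

fn vectorize_files(
    vectorizer: &CountVectorizer,
    filenames: &[String],
) -> Result<Array2<usize>, Box<dyn std::error::Error>> {
    // transform_files now returns Result<CsMat<usize>>, so `?` works here
    // instead of the `.unwrap()` used in the examples.
    Ok(vectorizer
        .transform_files(filenames, ISO_8859_1, Strict)?
        .to_dense())
}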

algorithms/linfa-preprocessing/examples/tfidf_vectorization.rs

Lines changed: 2 additions & 0 deletions

@@ -126,6 +126,7 @@ fn main() {
     // Transforming gives a sparse dataset, we make it dense in order to be able to fit the Naive Bayes model
     let training_records = vectorizer
         .transform_files(&training_filenames, ISO_8859_1, Strict)
+        .unwrap()
         .to_dense();
 
     println!(
@@ -162,6 +163,7 @@ fn main() {
     );
     let test_records = vectorizer
         .transform_files(&test_filenames, ISO_8859_1, Strict)
+        .unwrap()
         .to_dense();
     let test_dataset: Dataset<f64, usize, Ix1> = (test_records, test_targets).into();
     // Let's predict the test data targets

algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs

Lines changed: 28 additions & 0 deletions

@@ -7,6 +7,8 @@ use std::collections::HashSet;
 #[cfg(feature = "serde")]
 use serde_crate::{Deserialize, Serialize};
 
+use super::TOKENIZERFP;
+
 #[derive(Clone, Debug)]
 #[cfg(not(feature = "serde"))]
 struct SerdeRegex(Regex);
@@ -71,9 +73,21 @@ pub struct CountVectorizerValidParams {
     normalize: bool,
     document_frequency: (f32, f32),
     stopwords: Option<HashSet<String>>,
+    max_features: Option<usize>,
+    #[cfg_attr(feature = "serde", serde(skip))]
+    pub(crate) tokenizer: Option<TOKENIZERFP>,
+    pub(crate) tokenizer_deserialization_guard: bool,
 }
 
 impl CountVectorizerValidParams {
+    pub fn tokenizer(&self) -> Option<TOKENIZERFP> {
+        self.tokenizer
+    }
+
+    pub fn max_features(&self) -> Option<usize> {
+        self.max_features
+    }
+
     pub fn convert_to_lowercase(&self) -> bool {
         self.convert_to_lowercase
     }
@@ -117,11 +131,25 @@ impl std::default::Default for CountVectorizerParams {
             normalize: true,
             document_frequency: (0., 1.),
             stopwords: None,
+            max_features: None,
+            tokenizer: None,
+            tokenizer_deserialization_guard: false,
         })
     }
 }
 
 impl CountVectorizerParams {
+    pub fn tokenizer(mut self, tokenizer: Option<TOKENIZERFP>) -> Self {
+        self.0.tokenizer = tokenizer;
+        self.0.tokenizer_deserialization_guard = tokenizer.is_some();
+        self
+    }
+
+    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
+        self.0.max_features = max_features;
+        self
+    }
+
     ///If true, all documents used for fitting will be converted to lowercase.
     pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
         self.0.convert_to_lowercase = convert_to_lowercase;
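
The new `tokenizer` field is a plain function pointer (the `TOKENIZERFP` alias defined in `countgrams/mod.rs` as `fn(&str) -> Vec<&str>`), which is why the accessor can simply copy it, why serde skips it, and why the deserialization guard exists at all. It also means only non-capturing functions qualify; a short sketch of the builder calls added here (the semicolon tokenizer and the 500-entry cap are arbitrary illustration values):

use linfa_preprocessing::CountVectorizer;

// A free function (or a non-capturing closure) coerces to `fn(&str) -> Vec<&str>`.
fn split_on_semicolons(doc: &str) -> Vec<&str> {
    doc.split(';').collect()
}

fn main() {
    // A closure that captures local state has its own type and will not coerce
    // to a plain `fn` pointer, so it could not be passed to `.tokenizer(...)`.
    let _params = CountVectorizer::params()
        .tokenizer(Some(split_on_semicolons))
        .max_features(Some(500));
}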

algorithms/linfa-preprocessing/src/countgrams/mod.rs

Lines changed: 62 additions & 15 deletions

@@ -1,11 +1,13 @@
 //! Count vectorization methods
 
+use std::cmp::Reverse;
 use std::collections::{HashMap, HashSet};
 use std::io::Read;
 use std::iter::IntoIterator;
 
 use encoding::types::EncodingRef;
 use encoding::DecoderTrap;
+use itertools::sorted;
 use ndarray::{Array1, ArrayBase, ArrayViewMut1, Data, Ix1};
 use regex::Regex;
 use sprs::{CsMat, CsVec};
@@ -21,6 +23,8 @@ use serde_crate::{Deserialize, Serialize};
 
 mod hyperparams;
 
+pub(crate) type TOKENIZERFP = fn(&str) -> Vec<&str>;
+
 impl CountVectorizerValidParams {
     /// Learns a vocabulary from the documents in `x`, according to the specified attributes and maps each
     /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
@@ -41,10 +45,11 @@
         }
 
         let mut vocabulary = self.filter_vocabulary(vocabulary, x.len());
+
         let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
 
         Ok(CountVectorizer {
-            vocabulary,
+            vocabulary: vocabulary,
             vec_vocabulary,
             properties: self.clone(),
         })
@@ -127,7 +132,7 @@
         let len_f32 = n_documents as f32;
         let (min_abs_df, max_abs_df) = ((min_df * len_f32) as usize, (max_df * len_f32) as usize);
 
-        if min_abs_df == 0 && max_abs_df == n_documents {
+        let vocabulary = if min_abs_df == 0 && max_abs_df == n_documents {
             match &self.stopwords() {
                 None => vocabulary,
                 Some(stopwords) => vocabulary
@@ -152,6 +157,19 @@
                     })
                     .collect(),
             }
+        };
+
+        if let Some(max_features) = self.max_features() {
+            sorted(
+                vocabulary
+                    .into_iter()
+                    .map(|(word, (x, freq))| (Reverse(freq), Reverse(word), x)),
+            )
+            .take(max_features)
+            .map(|(freq, word, x)| (word.0, (x, freq.0)))
+            .collect()
+        } else {
+            vocabulary
         }
     }
 
@@ -164,7 +182,11 @@
         regex: &Regex,
         vocabulary: &mut HashMap<String, (usize, usize)>,
     ) {
-        let words = regex.find_iter(&doc).map(|mat| mat.as_str()).collect();
+        let words = if let Some(tokenizer) = self.tokenizer() {
+            tokenizer(&doc)
+        } else {
+            regex.find_iter(&doc).map(|mat| mat.as_str()).collect()
+        };
         let list = NGramList::new(words, self.n_gram_range());
         let document_vocabulary: HashSet<String> = list.into_iter().flatten().collect();
         for word in document_vocabulary {
@@ -253,13 +275,30 @@ impl CountVectorizer {
         self.vocabulary.len()
     }
 
+    pub fn force_tokenizer_redefinition(&mut self, tokenizer: Option<TOKENIZERFP>) {
+        self.properties.tokenizer = tokenizer;
+    }
+
+    pub(crate) fn validate_deserialization(&self) -> Result<()> {
+        if self.properties.tokenizer().is_none() && self.properties.tokenizer_deserialization_guard
+        {
+            return Err(PreprocessingError::TokenizerNotSet);
+        }
+
+        Ok(())
+    }
+
     /// Given a sequence of `n` documents, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
     /// is the number of occurrences of vocabulary entry `j` in the document of index `i`. Vocabulary entry `j` is the string
     /// at the `j`-th position in the vocabulary. If a vocabulary entry was not encountered in a document, then the relative
     /// cell in the sparse matrix will be set to `None`.
-    pub fn transform<T: ToString, D: Data<Elem = T>>(&self, x: &ArrayBase<D, Ix1>) -> CsMat<usize> {
+    pub fn transform<T: ToString, D: Data<Elem = T>>(
+        &self,
+        x: &ArrayBase<D, Ix1>,
+    ) -> Result<CsMat<usize>> {
+        self.validate_deserialization()?;
         let (vectorized, _) = self.get_term_and_document_frequencies(x);
-        vectorized
+        Ok(vectorized)
     }
 
     /// Given a sequence of `n` file names, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
@@ -274,9 +313,10 @@
         input: &[P],
         encoding: EncodingRef,
         trap: DecoderTrap,
-    ) -> CsMat<usize> {
+    ) -> Result<CsMat<usize>> {
+        self.validate_deserialization()?;
         let (vectorized, _) = self.get_term_and_document_frequencies_files(input, encoding, trap);
-        vectorized
+        Ok(vectorized)
     }
 
     /// Contains all vocabulary entries, in the same order used by the `transform` methods.
@@ -341,7 +381,11 @@
         // in sparse cases.
         let mut term_frequencies: Array1<usize> = Array1::zeros(self.vocabulary.len());
         let string = transform_string(document, &self.properties);
-        let words = regex.find_iter(&string).map(|mat| mat.as_str()).collect();
+        let words = if let Some(tokenizer) = self.properties.tokenizer() {
+            tokenizer(&string)
+        } else {
+            regex.find_iter(&string).map(|mat| mat.as_str()).collect()
+        };
         let list = NGramList::new(words, self.properties.n_gram_range());
         for ngram_items in list {
             for item in ngram_items {
@@ -408,7 +452,7 @@ mod tests {
         let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["one", "two", "three", "four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -425,7 +469,7 @@
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["one two", "two three", "three four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -441,7 +485,7 @@
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec![
             "one",
             "one two",
@@ -479,7 +523,7 @@
             .unwrap();
         let vect_vocabulary = vectorizer.vocabulary();
         assert_vocabulary_eq(&vocabulary, vect_vocabulary);
-        let transformed: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let transformed: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         assert_counts_for_word!(
             vect_vocabulary,
             transformed,
@@ -499,7 +543,7 @@
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["one", "two", "three", "four", "three;four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -521,7 +565,7 @@
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -549,7 +593,7 @@
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO", "three;four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -601,6 +645,7 @@
             encoding::all::UTF_8,
             encoding::DecoderTrap::Strict,
         )
+        .unwrap()
         .to_dense();
         let true_vocabulary = vec!["one", "two", "three", "four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
@@ -628,6 +673,7 @@
             encoding::all::UTF_8,
             encoding::DecoderTrap::Strict,
         )
+        .unwrap()
         .to_dense();
         let true_vocabulary = vec!["one two", "two three", "three four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
@@ -654,6 +700,7 @@
             encoding::all::UTF_8,
             encoding::DecoderTrap::Strict,
         )
+        .unwrap()
         .to_dense();
         let true_vocabulary = vec![
             "one",

algorithms/linfa-preprocessing/src/error.rs

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,8 @@ pub enum PreprocessingError {
     #[error("not a valid float")]
     InvalidFloat,
+    #[error("Tokenizer must be defined after deserializing CountVectorizer by calling force_tokenizer_redefinition")]
+    TokenizerNotSet,
     #[error("minimum value for MinMax scaler cannot be greater than the maximum")]
     FlippedMinMaxRange,
     #[error("n_gram boundaries cannot be zero (min = {0}, max = {1})")]
     InvalidNGramBoundaries(usize, usize),
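
Because the tokenizer is a bare function pointer that serde skips, a `CountVectorizer` fitted with a custom tokenizer and then deserialized keeps its guard flag but loses the function, so every transform call fails with `TokenizerNotSet` until the caller re-attaches one. A sketch of that recovery path (the serde round-trip itself is elided, and `split_tokenizer` stands in for whatever function was used at fit time):

use linfa_preprocessing::CountVectorizer;
use ndarray::Array1;

fn split_tokenizer(doc: &str) -> Vec<&str> {
    doc.split(';').collect()
}

// `vectorizer` is assumed to have just been deserialized.
fn restore_and_transform(mut vectorizer: CountVectorizer, texts: &Array1<String>) {
    // The guard is set but the tokenizer is gone, so transform is rejected...
    assert!(vectorizer.transform(texts).is_err());
    // ...until the tokenizer is redefined by hand, as the error message suggests.
    vectorizer.force_tokenizer_redefinition(Some(split_tokenizer));
    let _counts = vectorizer.transform(texts).unwrap();
}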
