add max_features and tokenizer to CountVectorizer
marco-cloudflare committed Feb 6, 2025
1 parent a30e5f1 commit 31a78fb
Showing 7 changed files with 128 additions and 21 deletions.
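For orientation, a minimal usage sketch (not part of this commit) of the two new options, assuming the builder methods and Result-returning transform introduced in the diff below, and assuming the crate re-exports CountVectorizer at its top level:

use linfa_preprocessing::CountVectorizer; // import path assumed
use ndarray::array;

// A custom tokenizer must be a plain fn pointer: fn(&str) -> Vec<&str> (the TOKENIZERFP alias below).
fn whitespace_tokenizer(doc: &str) -> Vec<&str> {
    doc.split_whitespace().collect()
}

fn main() {
    let texts = array!["one two three four", "two three four", "three four", "four"];
    let vectorizer = CountVectorizer::params()
        // Keep only the two most document-frequent vocabulary entries.
        .max_features(Some(2))
        // Replace the default regex tokenization with the custom function.
        .tokenizer(Some(whitespace_tokenizer as fn(&str) -> Vec<&str>))
        .fit(&texts)
        .unwrap();
    // transform now returns a Result, so unwrap (or propagate) before densifying.
    let counts = vectorizer.transform(&texts).unwrap().to_dense();
    println!("{:?} -> {:?}", vectorizer.vocabulary(), counts);
}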
1 change: 1 addition & 0 deletions algorithms/linfa-preprocessing/Cargo.toml
@@ -32,6 +32,7 @@ encoding = "0.2"
sprs = { version = "=0.11.1", default-features = false }

serde_regex = { version = "1.1", optional = true }
itertools = "0.14.0"

[dependencies.serde_crate]
package = "serde"
@@ -126,6 +126,7 @@ fn main() {
// Transforming gives a sparse dataset; we make it dense in order to be able to fit the Naive Bayes model
let training_records = vectorizer
.transform_files(&training_filenames, ISO_8859_1, Strict)
.unwrap()
.to_dense();
// Currently linfa only allows real valued features so we have to transform the integer counts to floats
let training_records = training_records.mapv(|c| c as f32);
@@ -164,6 +165,7 @@ fn main() {
);
let test_records = vectorizer
.transform_files(&test_filenames, ISO_8859_1, Strict)
.unwrap()
.to_dense();
let test_records = test_records.mapv(|c| c as f32);
let test_dataset: Dataset<f32, usize, Ix1> = (test_records, test_targets).into();
@@ -126,6 +126,7 @@ fn main() {
// Transforming gives a sparse dataset; we make it dense in order to be able to fit the Naive Bayes model
let training_records = vectorizer
.transform_files(&training_filenames, ISO_8859_1, Strict)
.unwrap()
.to_dense();

println!(
@@ -162,6 +163,7 @@ fn main() {
);
let test_records = vectorizer
.transform_files(&test_filenames, ISO_8859_1, Strict)
.unwrap()
.to_dense();
let test_dataset: Dataset<f64, usize, Ix1> = (test_records, test_targets).into();
// Let's predict the test data targets
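Both example programs above now call .unwrap() on transform_files because it returns a Result after this change. An alternative sketch that propagates the error with ? instead (imports and paths assumed, not part of the commit):

use std::error::Error;

use encoding::all::ISO_8859_1;
use encoding::DecoderTrap::Strict;
use linfa_preprocessing::CountVectorizer; // import path assumed
use ndarray::Array2;

// Transform a batch of files and densify, forwarding any preprocessing error to the caller.
fn densify_files(
    vectorizer: &CountVectorizer,
    filenames: &[String],
) -> Result<Array2<usize>, Box<dyn Error>> {
    let records = vectorizer
        .transform_files(filenames, ISO_8859_1, Strict)?
        .to_dense();
    Ok(records)
}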
28 changes: 28 additions & 0 deletions algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs
@@ -7,6 +7,8 @@ use std::collections::HashSet;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use super::TOKENIZERFP;

#[derive(Clone, Debug)]
#[cfg(not(feature = "serde"))]
struct SerdeRegex(Regex);
@@ -71,9 +73,21 @@ pub struct CountVectorizerValidParams {
normalize: bool,
document_frequency: (f32, f32),
stopwords: Option<HashSet<String>>,
max_features: Option<usize>,
#[cfg_attr(feature = "serde", serde(skip))]
pub(crate) tokenizer: Option<TOKENIZERFP>,
pub(crate) tokenizer_deserialization_guard: bool,
}

impl CountVectorizerValidParams {
pub fn tokenizer(&self) -> Option<TOKENIZERFP> {
self.tokenizer
}

pub fn max_features(&self) -> Option<usize> {
self.max_features
}

pub fn convert_to_lowercase(&self) -> bool {
self.convert_to_lowercase
}
@@ -117,11 +131,25 @@ impl std::default::Default for CountVectorizerParams {
normalize: true,
document_frequency: (0., 1.),
stopwords: None,
max_features: None,
tokenizer: None,
tokenizer_deserialization_guard: false,
})
}
}

impl CountVectorizerParams {
pub fn tokenizer(mut self, tokenizer: Option<TOKENIZERFP>) -> Self {
self.0.tokenizer = tokenizer;
self.0.tokenizer_deserialization_guard = tokenizer.is_some();
self
}

pub fn max_features(mut self, max_features: Option<usize>) -> Self {
self.0.max_features = max_features;
self
}

///If true, all documents used for fitting will be converted to lowercase.
pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
self.0.convert_to_lowercase = convert_to_lowercase;
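One note before the countgrams/mod.rs changes below: the tokenizer parameter declared above is stored as a plain function pointer (Option<TOKENIZERFP>, i.e. Option<fn(&str) -> Vec<&str>>), so only non-capturing functions (or non-capturing closures) can be supplied. A hedged sketch, with the import path assumed:

use linfa_preprocessing::CountVectorizer; // import path assumed

// A free function coerces to the expected fn-pointer type.
fn split_on_semicolons(doc: &str) -> Vec<&str> {
    doc.split(';').collect()
}

fn main() {
    let params = CountVectorizer::params()
        .tokenizer(Some(split_on_semicolons as fn(&str) -> Vec<&str>))
        // None (the default) keeps the full vocabulary.
        .max_features(None);

    // A closure that captures its environment would not compile here,
    // because it cannot coerce to a plain fn pointer:
    // let sep = ';';
    // let params = params.tokenizer(Some(move |doc: &str| doc.split(sep).collect()));
    let _ = params;
}

Because a function pointer cannot be serialized, the tokenizer_deserialization_guard flag above records whether a tokenizer was set; validate_deserialization in mod.rs below uses it to reject a deserialized vectorizer whose tokenizer has not been re-attached.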
77 changes: 62 additions & 15 deletions algorithms/linfa-preprocessing/src/countgrams/mod.rs
@@ -1,11 +1,13 @@
//! Count vectorization methods
use std::cmp::Reverse;
use std::collections::{HashMap, HashSet};
use std::io::Read;
use std::iter::IntoIterator;

use encoding::types::EncodingRef;
use encoding::DecoderTrap;
use itertools::sorted;
use ndarray::{Array1, ArrayBase, ArrayViewMut1, Data, Ix1};
use regex::Regex;
use sprs::{CsMat, CsVec};
@@ -21,6 +23,8 @@ use serde_crate::{Deserialize, Serialize};

mod hyperparams;

pub(crate) type TOKENIZERFP = fn(&str) -> Vec<&str>;

impl CountVectorizerValidParams {
/// Learns a vocabulary from the documents in `x`, according to the specified attributes and maps each
/// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
@@ -41,10 +45,11 @@ impl CountVectorizerValidParams {
}

let mut vocabulary = self.filter_vocabulary(vocabulary, x.len());

let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);

Ok(CountVectorizer {
vocabulary,
vocabulary: vocabulary,
vec_vocabulary,
properties: self.clone(),
})
@@ -127,7 +132,7 @@ impl CountVectorizerValidParams {
let len_f32 = n_documents as f32;
let (min_abs_df, max_abs_df) = ((min_df * len_f32) as usize, (max_df * len_f32) as usize);

if min_abs_df == 0 && max_abs_df == n_documents {
let vocabulary = if min_abs_df == 0 && max_abs_df == n_documents {
match &self.stopwords() {
None => vocabulary,
Some(stopwords) => vocabulary
@@ -152,6 +157,19 @@ impl CountVectorizerValidParams {
})
.collect(),
}
};

if let Some(max_features) = self.max_features() {
sorted(
vocabulary
.into_iter()
.map(|(word, (x, freq))| (Reverse(freq), Reverse(word), x)),
)
.take(max_features)
.map(|(freq, word, x)| (word.0, (x, freq.0)))
.collect()
} else {
vocabulary
}
}
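As a standalone illustration (not part of the commit) of what the sorted(...).take(max_features) chain above computes — keep the max_features entries with the highest document frequency, breaking ties through Reverse(word) — here is a self-contained sketch built on the same itertools::sorted and std::cmp::Reverse combination:

use std::cmp::Reverse;
use std::collections::HashMap;

use itertools::sorted;

fn main() {
    // word -> (column index, document frequency), as in the vocabulary map above.
    let vocabulary: HashMap<String, (usize, usize)> = [
        ("apple".to_string(), (0, 3)),
        ("banana".to_string(), (1, 5)),
        ("cherry".to_string(), (2, 5)),
        ("date".to_string(), (3, 1)),
    ]
    .into_iter()
    .collect();

    let max_features = 2;
    let kept: HashMap<String, (usize, usize)> = sorted(
        vocabulary
            .into_iter()
            .map(|(word, (idx, freq))| (Reverse(freq), Reverse(word), idx)),
    )
    .take(max_features)
    .map(|(freq, word, idx)| (word.0, (idx, freq.0)))
    .collect();

    // "banana" and "cherry" share the highest document frequency (5), so they survive.
    assert_eq!(kept.len(), 2);
    assert!(kept.contains_key("banana") && kept.contains_key("cherry"));
    println!("{:?}", kept);
}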

@@ -164,7 +182,11 @@ impl CountVectorizerValidParams {
regex: &Regex,
vocabulary: &mut HashMap<String, (usize, usize)>,
) {
let words = regex.find_iter(&doc).map(|mat| mat.as_str()).collect();
let words = if let Some(tokenizer) = self.tokenizer() {
tokenizer(&doc)
} else {
regex.find_iter(&doc).map(|mat| mat.as_str()).collect()
};
let list = NGramList::new(words, self.n_gram_range());
let document_vocabulary: HashSet<String> = list.into_iter().flatten().collect();
for word in document_vocabulary {
@@ -253,13 +275,30 @@ impl CountVectorizer {
self.vocabulary.len()
}

pub fn force_tokenizer_redefinition(&mut self, tokenizer: Option<TOKENIZERFP>) {
self.properties.tokenizer = tokenizer;
}

pub(crate) fn validate_deserialization(&self) -> Result<()> {
if self.properties.tokenizer().is_none() && self.properties.tokenizer_deserialization_guard
{
return Err(PreprocessingError::TokenizerNotSet);
}

Ok(())
}

/// Given a sequence of `n` documents, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
/// is the number of occurrences of vocabulary entry `j` in the document of index `i`. Vocabulary entry `j` is the string
/// at the `j`-th position in the vocabulary. If a vocabulary entry was not encountered in a document, then the corresponding
/// cell in the sparse matrix will be set to `None`.
pub fn transform<T: ToString, D: Data<Elem = T>>(&self, x: &ArrayBase<D, Ix1>) -> CsMat<usize> {
pub fn transform<T: ToString, D: Data<Elem = T>>(
&self,
x: &ArrayBase<D, Ix1>,
) -> Result<CsMat<usize>> {
self.validate_deserialization()?;
let (vectorized, _) = self.get_term_and_document_frequencies(x);
vectorized
Ok(vectorized)
}

/// Given a sequence of `n` file names, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
@@ -274,9 +313,10 @@ impl CountVectorizer {
input: &[P],
encoding: EncodingRef,
trap: DecoderTrap,
) -> CsMat<usize> {
) -> Result<CsMat<usize>> {
self.validate_deserialization()?;
let (vectorized, _) = self.get_term_and_document_frequencies_files(input, encoding, trap);
vectorized
Ok(vectorized)
}

/// Contains all vocabulary entries, in the same order used by the `transform` methods.
@@ -341,7 +381,11 @@ impl CountVectorizer {
// in sparse cases.
let mut term_frequencies: Array1<usize> = Array1::zeros(self.vocabulary.len());
let string = transform_string(document, &self.properties);
let words = regex.find_iter(&string).map(|mat| mat.as_str()).collect();
let words = if let Some(tokenizer) = self.properties.tokenizer() {
tokenizer(&string)
} else {
regex.find_iter(&string).map(|mat| mat.as_str()).collect()
};
let list = NGramList::new(words, self.properties.n_gram_range());
for ngram_items in list {
for item in ngram_items {
@@ -408,7 +452,7 @@ mod tests {
let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
let vectorizer = CountVectorizer::params().fit(&texts).unwrap();
let vocabulary = vectorizer.vocabulary();
let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
let true_vocabulary = vec!["one", "two", "three", "four"];
assert_vocabulary_eq(&true_vocabulary, vocabulary);
assert_counts_for_word!(
@@ -425,7 +469,7 @@ mod tests {
.fit(&texts)
.unwrap();
let vocabulary = vectorizer.vocabulary();
let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
let true_vocabulary = vec!["one two", "two three", "three four"];
assert_vocabulary_eq(&true_vocabulary, vocabulary);
assert_counts_for_word!(
@@ -441,7 +485,7 @@ mod tests {
.fit(&texts)
.unwrap();
let vocabulary = vectorizer.vocabulary();
let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
let true_vocabulary = vec![
"one",
"one two",
@@ -479,7 +523,7 @@ mod tests {
.unwrap();
let vect_vocabulary = vectorizer.vocabulary();
assert_vocabulary_eq(&vocabulary, vect_vocabulary);
let transformed: Array2<usize> = vectorizer.transform(&texts).to_dense();
let transformed: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
assert_counts_for_word!(
vect_vocabulary,
transformed,
@@ -499,7 +543,7 @@ mod tests {
.fit(&texts)
.unwrap();
let vocabulary = vectorizer.vocabulary();
let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
let true_vocabulary = vec!["one", "two", "three", "four", "three;four"];
assert_vocabulary_eq(&true_vocabulary, vocabulary);
assert_counts_for_word!(
@@ -521,7 +565,7 @@ mod tests {
.fit(&texts)
.unwrap();
let vocabulary = vectorizer.vocabulary();
let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO"];
assert_vocabulary_eq(&true_vocabulary, vocabulary);
assert_counts_for_word!(
@@ -549,7 +593,7 @@ mod tests {
.fit(&texts)
.unwrap();
let vocabulary = vectorizer.vocabulary();
let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO", "three;four"];
assert_vocabulary_eq(&true_vocabulary, vocabulary);
assert_counts_for_word!(
@@ -601,6 +645,7 @@ mod tests {
encoding::all::UTF_8,
encoding::DecoderTrap::Strict,
)
.unwrap()
.to_dense();
let true_vocabulary = vec!["one", "two", "three", "four"];
assert_vocabulary_eq(&true_vocabulary, vocabulary);
@@ -628,6 +673,7 @@ mod tests {
encoding::all::UTF_8,
encoding::DecoderTrap::Strict,
)
.unwrap()
.to_dense();
let true_vocabulary = vec!["one two", "two three", "three four"];
assert_vocabulary_eq(&true_vocabulary, vocabulary);
@@ -654,6 +700,7 @@ mod tests {
encoding::all::UTF_8,
encoding::DecoderTrap::Strict,
)
.unwrap()
.to_dense();
let true_vocabulary = vec![
"one",
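To motivate the TokenizerNotSet error added in error.rs below: a fitted vectorizer can be serialized, but a function pointer cannot, so after deserialization the tokenizer must be re-attached with force_tokenizer_redefinition before transform succeeds. A hedged sketch, assuming the crate's serde feature is enabled, serde_json is available, and the import path shown; this code is not part of the commit:

use linfa_preprocessing::CountVectorizer; // import path assumed
use ndarray::array;

fn semicolon_tokenizer(doc: &str) -> Vec<&str> {
    doc.split(';').collect()
}

fn main() {
    let texts = array!["one;two", "two;three"];
    let vectorizer = CountVectorizer::params()
        .tokenizer(Some(semicolon_tokenizer as fn(&str) -> Vec<&str>))
        .fit(&texts)
        .unwrap();

    // The tokenizer field is skipped during serialization; only the guard flag survives.
    let json = serde_json::to_string(&vectorizer).unwrap();
    let mut restored: CountVectorizer = serde_json::from_str(&json).unwrap();

    // Without re-attaching a tokenizer, transform reports TokenizerNotSet.
    assert!(restored.transform(&texts).is_err());

    // Re-attach the same function and transform works again.
    restored.force_tokenizer_redefinition(Some(semicolon_tokenizer as fn(&str) -> Vec<&str>));
    let counts = restored.transform(&texts).unwrap().to_dense();
    println!("{:?}", counts);
}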
2 changes: 2 additions & 0 deletions algorithms/linfa-preprocessing/src/error.rs
@@ -14,6 +14,8 @@ pub enum PreprocessingError {
#[error("not a valid float")]
InvalidFloat,
#[error("Tokenizer must be defined after deserializing CountVectorizer by calling force_tokenizer_redefinition")]
TokenizerNotSet,
#[error("minimum value for MinMax scaler cannot be greater than the maximum")]
FlippedMinMaxRange,
#[error("n_gram boundaries cannot be zero (min = {0}, max = {1})")]
InvalidNGramBoundaries(usize, usize),
