11//! Count vectorization methods
22
3+ use std:: cmp:: Reverse ;
34use std:: collections:: { HashMap , HashSet } ;
45use std:: io:: Read ;
56use std:: iter:: IntoIterator ;
67
78use encoding:: types:: EncodingRef ;
89use encoding:: DecoderTrap ;
10+ use itertools:: sorted;
911use ndarray:: { Array1 , ArrayBase , ArrayViewMut1 , Data , Ix1 } ;
1012use regex:: Regex ;
1113use sprs:: { CsMat , CsVec } ;
@@ -21,6 +23,8 @@ use serde_crate::{Deserialize, Serialize};
2123
2224mod hyperparams;
2325
/// Signature of a user-supplied tokenizer: takes a document and returns its
/// tokens as sub-slices borrowed from that document.
///
/// This is a plain function pointer (not a closure type) so that
/// `CountVectorizerValidParams` stays `Clone`/comparable; note it cannot be
/// serialized, hence the deserialization guard elsewhere in this module.
/// NOTE(review): the all-caps name violates Rust's UpperCamelCase convention
/// for type aliases, but it is kept because it is part of the crate-internal
/// interface already used by callers.
pub(crate) type TOKENIZERFP = fn(&str) -> Vec<&str>;
27+
2428impl CountVectorizerValidParams {
2529 /// Learns a vocabulary from the documents in `x`, according to the specified attributes and maps each
2630 /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
@@ -41,10 +45,11 @@ impl CountVectorizerValidParams {
4145 }
4246
4347 let mut vocabulary = self . filter_vocabulary ( vocabulary, x. len ( ) ) ;
48+
4449 let vec_vocabulary = hashmap_to_vocabulary ( & mut vocabulary) ;
4550
4651 Ok ( CountVectorizer {
47- vocabulary,
52+ vocabulary : vocabulary ,
4853 vec_vocabulary,
4954 properties : self . clone ( ) ,
5055 } )
@@ -127,7 +132,7 @@ impl CountVectorizerValidParams {
127132 let len_f32 = n_documents as f32 ;
128133 let ( min_abs_df, max_abs_df) = ( ( min_df * len_f32) as usize , ( max_df * len_f32) as usize ) ;
129134
130- if min_abs_df == 0 && max_abs_df == n_documents {
135+ let vocabulary = if min_abs_df == 0 && max_abs_df == n_documents {
131136 match & self . stopwords ( ) {
132137 None => vocabulary,
133138 Some ( stopwords) => vocabulary
@@ -152,6 +157,19 @@ impl CountVectorizerValidParams {
152157 } )
153158 . collect ( ) ,
154159 }
160+ } ;
161+
162+ if let Some ( max_features) = self . max_features ( ) {
163+ sorted (
164+ vocabulary
165+ . into_iter ( )
166+ . map ( |( word, ( x, freq) ) | ( Reverse ( freq) , Reverse ( word) , x) ) ,
167+ )
168+ . take ( max_features)
169+ . map ( |( freq, word, x) | ( word. 0 , ( x, freq. 0 ) ) )
170+ . collect ( )
171+ } else {
172+ vocabulary
155173 }
156174 }
157175
@@ -164,7 +182,11 @@ impl CountVectorizerValidParams {
164182 regex : & Regex ,
165183 vocabulary : & mut HashMap < String , ( usize , usize ) > ,
166184 ) {
167- let words = regex. find_iter ( & doc) . map ( |mat| mat. as_str ( ) ) . collect ( ) ;
185+ let words = if let Some ( tokenizer) = self . tokenizer ( ) {
186+ tokenizer ( & doc)
187+ } else {
188+ regex. find_iter ( & doc) . map ( |mat| mat. as_str ( ) ) . collect ( )
189+ } ;
168190 let list = NGramList :: new ( words, self . n_gram_range ( ) ) ;
169191 let document_vocabulary: HashSet < String > = list. into_iter ( ) . flatten ( ) . collect ( ) ;
170192 for word in document_vocabulary {
@@ -253,13 +275,30 @@ impl CountVectorizer {
253275 self . vocabulary . len ( )
254276 }
255277
278+ pub fn force_tokenizer_redefinition ( & mut self , tokenizer : Option < TOKENIZERFP > ) {
279+ self . properties . tokenizer = tokenizer;
280+ }
281+
282+ pub ( crate ) fn validate_deserialization ( & self ) -> Result < ( ) > {
283+ if self . properties . tokenizer ( ) . is_none ( ) && self . properties . tokenizer_deserialization_guard
284+ {
285+ return Err ( PreprocessingError :: TokenizerNotSet ) ;
286+ }
287+
288+ Ok ( ( ) )
289+ }
290+
256291 /// Given a sequence of `n` documents, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
257292 /// is the number of occurrences of vocabulary entry `j` in the document of index `i`. Vocabulary entry `j` is the string
258293 /// at the `j`-th position in the vocabulary. If a vocabulary entry was not encountered in a document, then the relative
259294 /// cell in the sparse matrix will be set to `None`.
260- pub fn transform < T : ToString , D : Data < Elem = T > > ( & self , x : & ArrayBase < D , Ix1 > ) -> CsMat < usize > {
295+ pub fn transform < T : ToString , D : Data < Elem = T > > (
296+ & self ,
297+ x : & ArrayBase < D , Ix1 > ,
298+ ) -> Result < CsMat < usize > > {
299+ self . validate_deserialization ( ) ?;
261300 let ( vectorized, _) = self . get_term_and_document_frequencies ( x) ;
262- vectorized
301+ Ok ( vectorized)
263302 }
264303
265304 /// Given a sequence of `n` file names, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
@@ -274,9 +313,10 @@ impl CountVectorizer {
274313 input : & [ P ] ,
275314 encoding : EncodingRef ,
276315 trap : DecoderTrap ,
277- ) -> CsMat < usize > {
316+ ) -> Result < CsMat < usize > > {
317+ self . validate_deserialization ( ) ?;
278318 let ( vectorized, _) = self . get_term_and_document_frequencies_files ( input, encoding, trap) ;
279- vectorized
319+ Ok ( vectorized)
280320 }
281321
282322 /// Contains all vocabulary entries, in the same order used by the `transform` methods.
@@ -341,7 +381,11 @@ impl CountVectorizer {
341381 // in sparse cases.
342382 let mut term_frequencies: Array1 < usize > = Array1 :: zeros ( self . vocabulary . len ( ) ) ;
343383 let string = transform_string ( document, & self . properties ) ;
344- let words = regex. find_iter ( & string) . map ( |mat| mat. as_str ( ) ) . collect ( ) ;
384+ let words = if let Some ( tokenizer) = self . properties . tokenizer ( ) {
385+ tokenizer ( & string)
386+ } else {
387+ regex. find_iter ( & string) . map ( |mat| mat. as_str ( ) ) . collect ( )
388+ } ;
345389 let list = NGramList :: new ( words, self . properties . n_gram_range ( ) ) ;
346390 for ngram_items in list {
347391 for item in ngram_items {
@@ -408,7 +452,7 @@ mod tests {
408452 let texts = array ! [ "oNe two three four" , "TWO three four" , "three;four" , "four" ] ;
409453 let vectorizer = CountVectorizer :: params ( ) . fit ( & texts) . unwrap ( ) ;
410454 let vocabulary = vectorizer. vocabulary ( ) ;
411- let counts: Array2 < usize > = vectorizer. transform ( & texts) . to_dense ( ) ;
455+ let counts: Array2 < usize > = vectorizer. transform ( & texts) . unwrap ( ) . to_dense ( ) ;
412456 let true_vocabulary = vec ! [ "one" , "two" , "three" , "four" ] ;
413457 assert_vocabulary_eq ( & true_vocabulary, vocabulary) ;
414458 assert_counts_for_word ! (
@@ -425,7 +469,7 @@ mod tests {
425469 . fit ( & texts)
426470 . unwrap ( ) ;
427471 let vocabulary = vectorizer. vocabulary ( ) ;
428- let counts: Array2 < usize > = vectorizer. transform ( & texts) . to_dense ( ) ;
472+ let counts: Array2 < usize > = vectorizer. transform ( & texts) . unwrap ( ) . to_dense ( ) ;
429473 let true_vocabulary = vec ! [ "one two" , "two three" , "three four" ] ;
430474 assert_vocabulary_eq ( & true_vocabulary, vocabulary) ;
431475 assert_counts_for_word ! (
@@ -441,7 +485,7 @@ mod tests {
441485 . fit ( & texts)
442486 . unwrap ( ) ;
443487 let vocabulary = vectorizer. vocabulary ( ) ;
444- let counts: Array2 < usize > = vectorizer. transform ( & texts) . to_dense ( ) ;
488+ let counts: Array2 < usize > = vectorizer. transform ( & texts) . unwrap ( ) . to_dense ( ) ;
445489 let true_vocabulary = vec ! [
446490 "one" ,
447491 "one two" ,
@@ -479,7 +523,7 @@ mod tests {
479523 . unwrap ( ) ;
480524 let vect_vocabulary = vectorizer. vocabulary ( ) ;
481525 assert_vocabulary_eq ( & vocabulary, vect_vocabulary) ;
482- let transformed: Array2 < usize > = vectorizer. transform ( & texts) . to_dense ( ) ;
526+ let transformed: Array2 < usize > = vectorizer. transform ( & texts) . unwrap ( ) . to_dense ( ) ;
483527 assert_counts_for_word ! (
484528 vect_vocabulary,
485529 transformed,
@@ -499,7 +543,7 @@ mod tests {
499543 . fit ( & texts)
500544 . unwrap ( ) ;
501545 let vocabulary = vectorizer. vocabulary ( ) ;
502- let counts: Array2 < usize > = vectorizer. transform ( & texts) . to_dense ( ) ;
546+ let counts: Array2 < usize > = vectorizer. transform ( & texts) . unwrap ( ) . to_dense ( ) ;
503547 let true_vocabulary = vec ! [ "one" , "two" , "three" , "four" , "three;four" ] ;
504548 assert_vocabulary_eq ( & true_vocabulary, vocabulary) ;
505549 assert_counts_for_word ! (
@@ -521,7 +565,7 @@ mod tests {
521565 . fit ( & texts)
522566 . unwrap ( ) ;
523567 let vocabulary = vectorizer. vocabulary ( ) ;
524- let counts: Array2 < usize > = vectorizer. transform ( & texts) . to_dense ( ) ;
568+ let counts: Array2 < usize > = vectorizer. transform ( & texts) . unwrap ( ) . to_dense ( ) ;
525569 let true_vocabulary = vec ! [ "oNe" , "two" , "three" , "four" , "TWO" ] ;
526570 assert_vocabulary_eq ( & true_vocabulary, vocabulary) ;
527571 assert_counts_for_word ! (
@@ -549,7 +593,7 @@ mod tests {
549593 . fit ( & texts)
550594 . unwrap ( ) ;
551595 let vocabulary = vectorizer. vocabulary ( ) ;
552- let counts: Array2 < usize > = vectorizer. transform ( & texts) . to_dense ( ) ;
596+ let counts: Array2 < usize > = vectorizer. transform ( & texts) . unwrap ( ) . to_dense ( ) ;
553597 let true_vocabulary = vec ! [ "oNe" , "two" , "three" , "four" , "TWO" , "three;four" ] ;
554598 assert_vocabulary_eq ( & true_vocabulary, vocabulary) ;
555599 assert_counts_for_word ! (
@@ -601,6 +645,7 @@ mod tests {
601645 encoding:: all:: UTF_8 ,
602646 encoding:: DecoderTrap :: Strict ,
603647 )
648+ . unwrap ( )
604649 . to_dense ( ) ;
605650 let true_vocabulary = vec ! [ "one" , "two" , "three" , "four" ] ;
606651 assert_vocabulary_eq ( & true_vocabulary, vocabulary) ;
@@ -628,6 +673,7 @@ mod tests {
628673 encoding:: all:: UTF_8 ,
629674 encoding:: DecoderTrap :: Strict ,
630675 )
676+ . unwrap ( )
631677 . to_dense ( ) ;
632678 let true_vocabulary = vec ! [ "one two" , "two three" , "three four" ] ;
633679 assert_vocabulary_eq ( & true_vocabulary, vocabulary) ;
@@ -654,6 +700,7 @@ mod tests {
654700 encoding:: all:: UTF_8 ,
655701 encoding:: DecoderTrap :: Strict ,
656702 )
703+ . unwrap ( )
657704 . to_dense ( ) ;
658705 let true_vocabulary = vec ! [
659706 "one" ,
0 commit comments