@@ -1,11 +1,13 @@
 //! Count vectorization methods
 
+use std::cmp::Reverse;
 use std::collections::{HashMap, HashSet};
 use std::io::Read;
 use std::iter::IntoIterator;
 
 use encoding::types::EncodingRef;
 use encoding::DecoderTrap;
+use itertools::sorted;
 use ndarray::{Array1, ArrayBase, ArrayViewMut1, Data, Ix1};
 use regex::Regex;
 use sprs::{CsMat, CsVec};
@@ -21,6 +23,8 @@ use serde_crate::{Deserialize, Serialize};
 
 mod hyperparams;
 
+pub(crate) type TOKENIZERFP = fn(&str) -> Vec<&str>;
+
 impl CountVectorizerValidParams {
     /// Learns a vocabulary from the documents in `x`, according to the specified attributes and maps each
     /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
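Note: `TOKENIZERFP` is a plain function pointer, so any `fn(&str) -> Vec<&str>` qualifies as a custom tokenizer. A minimal sketch (the function name is illustrative, not part of this commit):

```rust
// Hypothetical tokenizer matching the TOKENIZERFP signature:
// split a document on whitespace instead of the default token regex.
fn whitespace_tokenizer(doc: &str) -> Vec<&str> {
    doc.split_whitespace().collect()
}
```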
@@ -41,10 +45,11 @@ impl CountVectorizerValidParams {
         }
 
         let mut vocabulary = self.filter_vocabulary(vocabulary, x.len());
+
         let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
 
         Ok(CountVectorizer {
-            vocabulary,
+            vocabulary: vocabulary,
             vec_vocabulary,
             properties: self.clone(),
         })
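For context, the fit/transform round trip exercised by the tests below now looks like this; a sketch using names from this file, with the extra `unwrap` coming from the new fallible `transform`:

```rust
// Learn a vocabulary with fit, then count occurrences per document.
let texts = ndarray::array!["one two three four", "two three four"];
let vectorizer = CountVectorizer::params().fit(&texts).unwrap();
let counts = vectorizer.transform(&texts).unwrap().to_dense();
```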
@@ -127,7 +132,7 @@ impl CountVectorizerValidParams {
         let len_f32 = n_documents as f32;
         let (min_abs_df, max_abs_df) = ((min_df * len_f32) as usize, (max_df * len_f32) as usize);
 
-        if min_abs_df == 0 && max_abs_df == n_documents {
+        let vocabulary = if min_abs_df == 0 && max_abs_df == n_documents {
             match &self.stopwords() {
                 None => vocabulary,
                 Some(stopwords) => vocabulary
@@ -152,6 +157,19 @@ impl CountVectorizerValidParams {
                     })
                     .collect(),
             }
+        };
+
+        if let Some(max_features) = self.max_features() {
+            sorted(
+                vocabulary
+                    .into_iter()
+                    .map(|(word, (x, freq))| (Reverse(freq), Reverse(word), x)),
+            )
+            .take(max_features)
+            .map(|(freq, word, x)| (word.0, (x, freq.0)))
+            .collect()
+        } else {
+            vocabulary
         }
     }
 
@@ -164,7 +182,11 @@ impl CountVectorizerValidParams {
         regex: &Regex,
         vocabulary: &mut HashMap<String, (usize, usize)>,
     ) {
-        let words = regex.find_iter(&doc).map(|mat| mat.as_str()).collect();
+        let words = if let Some(tokenizer) = self.tokenizer() {
+            tokenizer(&doc)
+        } else {
+            regex.find_iter(&doc).map(|mat| mat.as_str()).collect()
+        };
         let list = NGramList::new(words, self.n_gram_range());
         let document_vocabulary: HashSet<String> = list.into_iter().flatten().collect();
         for word in document_vocabulary {
@@ -253,13 +275,30 @@ impl CountVectorizer {
         self.vocabulary.len()
     }
 
+    pub fn force_tokenizer_redefinition(&mut self, tokenizer: Option<TOKENIZERFP>) {
+        self.properties.tokenizer = tokenizer;
+    }
+
+    pub(crate) fn validate_deserialization(&self) -> Result<()> {
+        if self.properties.tokenizer().is_none() && self.properties.tokenizer_deserialization_guard
+        {
+            return Err(PreprocessingError::TokenizerNotSet);
+        }
+
+        Ok(())
+    }
+
     /// Given a sequence of `n` documents, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
     /// is the number of occurrences of vocabulary entry `j` in the document of index `i`. Vocabulary entry `j` is the string
     /// at the `j`-th position in the vocabulary. If a vocabulary entry was not encountered in a document, then the relative
     /// cell in the sparse matrix will be set to `None`.
-    pub fn transform<T: ToString, D: Data<Elem = T>>(&self, x: &ArrayBase<D, Ix1>) -> CsMat<usize> {
+    pub fn transform<T: ToString, D: Data<Elem = T>>(
+        &self,
+        x: &ArrayBase<D, Ix1>,
+    ) -> Result<CsMat<usize>> {
+        self.validate_deserialization()?;
         let (vectorized, _) = self.get_term_and_document_frequencies(x);
-        vectorized
+        Ok(vectorized)
     }
 
     /// Given a sequence of `n` file names, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
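Making `transform` fallible ties into serialization: a function pointer cannot be serialized, so a deserialized `CountVectorizer` comes back with no tokenizer, and when `tokenizer_deserialization_guard` is set it refuses to transform until one is restored via `force_tokenizer_redefinition`. A hedged sketch of that recovery path (`my_tokenizer` and `restore` are illustrative):

```rust
// Hypothetical recovery after deserialization: fn pointers are not
// serialized, so the original tokenizer must be re-attached by hand.
fn my_tokenizer(doc: &str) -> Vec<&str> {
    doc.split_whitespace().collect()
}

fn restore(mut vectorizer: CountVectorizer) -> CountVectorizer {
    // Without this, transform() returns PreprocessingError::TokenizerNotSet
    // whenever the deserialization guard is set.
    vectorizer.force_tokenizer_redefinition(Some(my_tokenizer));
    vectorizer
}
```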
@@ -274,9 +313,10 @@ impl CountVectorizer {
         input: &[P],
         encoding: EncodingRef,
         trap: DecoderTrap,
-    ) -> CsMat<usize> {
+    ) -> Result<CsMat<usize>> {
+        self.validate_deserialization()?;
         let (vectorized, _) = self.get_term_and_document_frequencies_files(input, encoding, trap);
-        vectorized
+        Ok(vectorized)
     }
 
     /// Contains all vocabulary entries, in the same order used by the `transform` methods.
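The file-based transform gets the same `Result` treatment. A sketch mirroring the file tests below, assuming the method is the crate's `transform_files` and the files exist on disk:

```rust
// Count terms across files; the extra unwrap comes from the new
// Result return type, checked after validate_deserialization.
let paths = ["doc1.txt", "doc2.txt"];
let counts = vectorizer
    .transform_files(&paths, encoding::all::UTF_8, encoding::DecoderTrap::Strict)
    .unwrap()
    .to_dense();
```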
@@ -341,7 +381,11 @@ impl CountVectorizer {
         // in sparse cases.
         let mut term_frequencies: Array1<usize> = Array1::zeros(self.vocabulary.len());
         let string = transform_string(document, &self.properties);
-        let words = regex.find_iter(&string).map(|mat| mat.as_str()).collect();
+        let words = if let Some(tokenizer) = self.properties.tokenizer() {
+            tokenizer(&string)
+        } else {
+            regex.find_iter(&string).map(|mat| mat.as_str()).collect()
+        };
         let list = NGramList::new(words, self.properties.n_gram_range());
         for ngram_items in list {
             for item in ngram_items {
@@ -408,7 +452,7 @@ mod tests {
         let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
         let vectorizer = CountVectorizer::params().fit(&texts).unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["one", "two", "three", "four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -425,7 +469,7 @@ mod tests {
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["one two", "two three", "three four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -441,7 +485,7 @@ mod tests {
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec![
             "one",
             "one two",
@@ -479,7 +523,7 @@ mod tests {
             .unwrap();
         let vect_vocabulary = vectorizer.vocabulary();
         assert_vocabulary_eq(&vocabulary, vect_vocabulary);
-        let transformed: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let transformed: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         assert_counts_for_word!(
             vect_vocabulary,
             transformed,
@@ -499,7 +543,7 @@ mod tests {
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["one", "two", "three", "four", "three;four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -521,7 +565,7 @@ mod tests {
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -549,7 +593,7 @@ mod tests {
             .fit(&texts)
             .unwrap();
         let vocabulary = vectorizer.vocabulary();
-        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
+        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
         let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO", "three;four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
         assert_counts_for_word!(
@@ -601,6 +645,7 @@ mod tests {
                 encoding::all::UTF_8,
                 encoding::DecoderTrap::Strict,
             )
+            .unwrap()
             .to_dense();
         let true_vocabulary = vec!["one", "two", "three", "four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
@@ -628,6 +673,7 @@ mod tests {
                 encoding::all::UTF_8,
                 encoding::DecoderTrap::Strict,
             )
+            .unwrap()
             .to_dense();
         let true_vocabulary = vec!["one two", "two three", "three four"];
         assert_vocabulary_eq(&true_vocabulary, vocabulary);
@@ -654,6 +700,7 @@ mod tests {
                 encoding::all::UTF_8,
                 encoding::DecoderTrap::Strict,
             )
+            .unwrap()
             .to_dense();
         let true_vocabulary = vec![
             "one",