|
| 1 | +from sklearn.base import clone |
1 | 2 | from sklearn.pipeline import make_pipeline, make_union |
2 | 3 | from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, SplineTransformer |
| 4 | +from sklearn.compose import make_column_transformer |
3 | 5 | from sklearn.feature_extraction.text import CountVectorizer |
4 | 6 | from skrub import SelectCols |
5 | 7 | from .transformer_functions import column_pluck, datetime_feats |
@@ -27,56 +29,37 @@ def select(*colnames): |
27 | 29 | pipeline=make_pipeline(SelectCols([col for col in colnames])) |
28 | 30 | ) |
29 | 31 |
|
30 | | -def onehot(*colnames): |
| 32 | +def onehot(*colnames, **kwargs): |
31 | 33 | """One-hot encode specified columns, resulting in a sparse set of features.""" |
32 | | - return select(*colnames) | OneHotEncoder() |
| 34 | + return select(*colnames) | OneHotEncoder(**kwargs) |
| 35 | + |
| 36 | +def minhash(*colnames, **kwargs): |
| 37 | + """Create min-hash features for specified columns, resulting in a dense set of features.""" |
| 38 | + from skrub import MinHashEncoder |
| 39 | + return estimator_for_all_columns(MinHashEncoder(**kwargs), *colnames) |
33 | 40 |
|
34 | 41 | def bag_of_words(*colnames, **kwargs): |
35 | 42 | """Generate bag-of-words features on a set of columns, assuming it refers to text.""" |
| 43 | + return estimator_for_all_columns(CountVectorizer(**kwargs), *colnames) |
36 | 44 |
|
37 | | - return PlaytimePipeline( |
38 | | - pipeline=make_union( |
39 | | - *[ |
40 | | - make_pipeline( |
41 | | - FunctionTransformer(column_pluck, kw_args={"column": col}), |
42 | | - CountVectorizer(**kwargs), |
43 | | - ) |
44 | | - for col in colnames |
45 | | - ] |
46 | | - ) |
47 | | - ) |
48 | 45 |
|
49 | 46 | def embed_text(*colnames, name='all-MiniLM-L6-v2', **kwargs): |
50 | 47 | """Generate text embedding features on a set of columns, assuming it refers to text.""" |
51 | 48 | from embetter.text import SentenceEncoder |
52 | 49 |
|
53 | | - return PlaytimePipeline( |
54 | | - pipeline=make_union( |
55 | | - *[ |
56 | | - make_pipeline( |
57 | | - FunctionTransformer(column_pluck, kw_args={"column": col}), |
58 | | - SentenceEncoder(name), |
59 | | - ) |
60 | | - for col in colnames |
61 | | - ] |
62 | | - ) |
63 | | - ) |
| 50 | + return estimator_for_all_columns(SentenceEncoder(name, **kwargs), *colnames) |
64 | 51 |
|
65 | 52 | def embed_image(*colnames): |
66 | 53 | """Generate image embedding features on a set of columns using CLIP, assuming it refers to an image path.""" |
67 | 54 | from embetter.grab import ColumnGrabber |
68 | 55 | from embetter.vision import ImageLoader |
69 | 56 | from embetter.multi import ClipEncoder |
70 | 57 |
|
71 | | - return PlaytimePipeline( |
72 | | - pipeline=make_union( |
73 | | - *[ |
74 | | - make_pipeline( |
75 | | - FunctionTransformer(column_pluck, kw_args={"column": col}), |
76 | | - ImageLoader(convert="RGB"), |
77 | | - ClipEncoder() |
78 | | - ) |
79 | | - for col in colnames |
80 | | - ] |
81 | | - ) |
| 58 | + est = make_pipeline( |
| 59 | + ImageLoader(convert="RGB"), |
| 60 | + ClipEncoder() |
82 | 61 | ) |
| 62 | + return estimator_for_all_columns(est, *colnames) |
| 63 | + |
| 64 | +def estimator_for_all_columns(estimator, *columns): |
| 65 | + return PlaytimePipeline(make_column_transformer(*[(clone(estimator), col) for col in columns])) |
0 commit comments