Skip to content

Commit 6b89449

Browse files
authored
Merge pull request #8 from koaning/utils
Made a utility func
2 parents 76b0883 + d13284a commit 6b89449

File tree

1 file changed

+18
-35
lines changed

1 file changed

+18
-35
lines changed

playtime/__init__.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from sklearn.base import clone
12
from sklearn.pipeline import make_pipeline, make_union
23
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, SplineTransformer
4+
from sklearn.compose import make_column_transformer
35
from sklearn.feature_extraction.text import CountVectorizer
46
from skrub import SelectCols
57
from .transformer_functions import column_pluck, datetime_feats
@@ -27,56 +29,37 @@ def select(*colnames):
2729
pipeline=make_pipeline(SelectCols([col for col in colnames]))
2830
)
2931

30-
def onehot(*colnames):
32+
def onehot(*colnames, **kwargs):
3133
"""One-hot encode specified columns, resulting in a sparse set of features."""
32-
return select(*colnames) | OneHotEncoder()
34+
return select(*colnames) | OneHotEncoder(**kwargs)
35+
36+
def minhash(*colnames, **kwargs):
37+
"""Create min-hash features for specified columns, resulting in a dense set of features."""
38+
from skrub import MinHashEncoder
39+
return estimator_for_all_columns(MinHashEncoder(**kwargs), *colnames)
3340

3441
def bag_of_words(*colnames, **kwargs):
3542
"""Generate bag-of-words features on a set of column, assuming it refers to text."""
43+
return estimator_for_all_columns(CountVectorizer(**kwargs), *colnames)
3644

37-
return PlaytimePipeline(
38-
pipeline=make_union(
39-
*[
40-
make_pipeline(
41-
FunctionTransformer(column_pluck, kw_args={"column": col}),
42-
CountVectorizer(**kwargs),
43-
)
44-
for col in colnames
45-
]
46-
)
47-
)
4845

4946
def embed_text(*colnames, name='all-MiniLM-L6-v2', **kwargs):
5047
"""Generate text embedding features on a set of columns, assuming it refers to text."""
5148
from embetter.text import SentenceEncoder
5249

53-
return PlaytimePipeline(
54-
pipeline=make_union(
55-
*[
56-
make_pipeline(
57-
FunctionTransformer(column_pluck, kw_args={"column": col}),
58-
SentenceEncoder(name),
59-
)
60-
for col in colnames
61-
]
62-
)
63-
)
50+
return estimator_for_all_columns(SentenceEncoder(name, **kwargs), *colnames)
6451

6552
def embed_image(*colnames):
6653
"""Generate image embedding features on a set of columns using CLIP, assuming it refers to an image path."""
6754
from embetter.grab import ColumnGrabber
6855
from embetter.vision import ImageLoader
6956
from embetter.multi import ClipEncoder
7057

71-
return PlaytimePipeline(
72-
pipeline=make_union(
73-
*[
74-
make_pipeline(
75-
FunctionTransformer(column_pluck, kw_args={"column": col}),
76-
ImageLoader(convert="RGB"),
77-
ClipEncoder()
78-
)
79-
for col in colnames
80-
]
81-
)
58+
est = make_pipeline(
59+
ImageLoader(convert="RGB"),
60+
ClipEncoder()
8261
)
62+
return estimator_for_all_columns(est, *colnames)
63+
64+
def estimator_for_all_columns(estimator, *columns):
65+
return PlaytimePipeline(make_column_transformer(*[(clone(estimator), col) for col in columns]))

0 commit comments

Comments
 (0)