Skip to content

Commit fdd07b0

Browse files
authored
Merge pull request #12 from adrinjalali/SelectCol
MNT remove skrub dep and have our own SelectCols
2 parents 1d2051f + 7ef6e76 commit fdd07b0

File tree

5 files changed

+97
-5
lines changed

5 files changed

+97
-5
lines changed

.github/workflows/unittest.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ jobs:
2929
source .venv/bin/activate
3030
which python
3131
python --version
32-
uv pip install -e . pytest
32+
uv pip install -e .[test]
3333
uv run pytest

playtime/__init__.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, SplineTransformer
44
from sklearn.compose import make_column_transformer
55
from sklearn.feature_extraction.text import CountVectorizer
6-
from skrub import SelectCols
6+
from .estimators import SelectCols
77
from .transformer_functions import datetime_feats
88
from .formula import PlaytimePipeline
99
from typing import Any
@@ -41,9 +41,7 @@ def select(*colnames: str) -> PlaytimePipeline:
4141
pipeline = select("col_a", "col_b")
4242
```
4343
"""
44-
return PlaytimePipeline(
45-
pipeline=make_pipeline(SelectCols([col for col in colnames]))
46-
)
44+
return PlaytimePipeline(pipeline=make_pipeline(SelectCols(colnames)))
4745

4846

4947
def onehot(*colnames: str, **kwargs) -> PlaytimePipeline:

playtime/estimators.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import narwhals as nw
2+
from sklearn.base import BaseEstimator, TransformerMixin
3+
4+
5+
class SelectCols(TransformerMixin, BaseEstimator):
    """Select a subset of a DataFrame's columns.

    ``fit`` raises a ``ValueError`` if any of the requested column names are
    not in the dataframe. ``transform`` does no check of its own: it can be
    called without fitting first, and a missing column then surfaces as
    whatever error the underlying dataframe library raises.

    Accepts anything accepted by ``narwhals.from_native(..., eager_only=True)``.

    Parameters
    ----------
    cols : list of str or str
        The columns to select. A single column name can be passed as a ``str``:
        ``"col_name"`` is the same as ``["col_name"]``.

    **Usage**
    ```python
    >>> import polars as pl
    >>> from playtime.estimators import SelectCols
    >>> df = pl.DataFrame({"A": [1, 2], "B": [10, 20], "C": ["x", "y"]})
    >>> SelectCols(["C", "A"]).fit_transform(df).columns
    ['C', 'A']
    >>> SelectCols(["X", "A"]).fit_transform(df)
    Traceback (most recent call last):
        ...
    ValueError: The following columns are requested for selection but missing from dataframe: ['X']
    ```
    """

    def __init__(self, cols=None):
        # scikit-learn convention: store constructor arguments untouched so
        # that `get_params` / `clone` round-trip them.
        self.cols = cols

    def fit(self, X, y=None):
        """Check that the requested columns exist in ``X``.

        Nothing is learned from the data; fitting only validates ``self.cols``
        against the columns of ``X``.

        Parameters
        ----------
        X : DataFrame
            The DataFrame whose columns are checked against ``self.cols``.

        y : None
            Unused.

        Returns
        -------
        SelectCols
            The transformer itself.

        Raises
        ------
        ValueError
            If one or more of the requested column names are not in ``X``.
        """
        frame = nw.from_native(X, eager_only=True)

        if self.cols is not None:
            requested = [self.cols] if isinstance(self.cols, str) else list(self.cols)
            available = set(frame.columns)
            # Only plain column names are checked here; anything else is left
            # to the `select` call below.
            missing = [
                name
                for name in requested
                if isinstance(name, str) and name not in available
            ]
            if missing:
                raise ValueError(
                    "The following columns are requested for selection but"
                    f" missing from dataframe: {missing}"
                )

        # Keep the original dry-run selection so that any other problem with
        # `self.cols` is still reported at fit time by the backend.
        frame.select(self.cols)
        return self

    def transform(self, X):
        """Transform a dataframe by selecting columns.

        Parameters
        ----------
        X : DataFrame
            The DataFrame on which to apply the selection.

        Returns
        -------
        DataFrame
            ``X`` restricted to the columns listed in ``self.cols`` (in the
            provided order). The result is the narwhals DataFrame produced by
            ``nw.from_native(...).select(...)``; it is not converted back to
            the native dataframe type here.
        """
        return nw.from_native(X, eager_only=True).select(self.cols)

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535

3636
[project.optional-dependencies]
3737
lint = ["pre-commit"]
38+
test = ["pytest", "pandas"]
3839

3940
[project.urls]
4041
repository = "https://github.com/koaning/playtime"

tests/test_estimators.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
from playtime.estimators import SelectCols
4+
5+
6+
@pytest.mark.parametrize("lib", ["pandas", "polars"])
def test_select_cols(lib):
    """Exercise ``SelectCols`` against each dataframe backend that is installed."""
    # The parametrized case is skipped when the backend cannot be imported.
    backend = pytest.importorskip(lib)

    data = {"a": [1, 2], "b": [10, 20], "c": ["x", "y"]}
    frame = backend.DataFrame(data)
    selector = SelectCols(["a", "b"])

    # `transform` must be usable on an instance that was never fitted.
    selected = selector.transform(frame)
    assert selected.columns == ["a", "b"]

    # Fitting on a frame that contains every requested column must not raise.
    selector.fit(frame)

    # Requesting a column that does not exist must fail at fit time. The
    # concrete exception class is not pinned here, so any exception counts.
    with pytest.raises(Exception):
        SelectCols(["a", "b", "d"]).fit(frame)

0 commit comments

Comments
 (0)