Skip to content

Commit 4e5d063

Browse files
committed
Merge branch 'main' into feature/ohe_values_2
2 parents ced7398 + 6fc1124 commit 4e5d063

15 files changed

+737
-564
lines changed

CHANGELOG.rst

+10-2
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,21 @@ Changed
3737
- narwhalified BaseNumericTransformer `#358 <https://github.com/lvgig/tubular/issues/358>_`
3838
- narwhalified DropOriginalMixin `#352 <https://github.com/lvgig/tubular/issues/352>_`
3939
- narwhalified BaseMappingTransformer `#367 <https://github.com/lvgig/tubular/issues/367>_`
40+
- narwhalified BaseMappingTransformerMixin. As part of this made mapping transformers more
41+
type-conscious, they now rely on an input 'return_dtypes' dict arg.
42+
`#369 <https://github.com/lvgig/tubular/issues/369>_`
43+
- As part of #369, updated OrdinalEncoderTransformer to output Int8 type
44+
- As part of #369, updated NominalToIntegerTransformer to output Int8 type. Removed inverse_mapping
45+
functionality, as this is more complicated when transform is opinionated on types.
46+
- narwhalified GroupRareLevelsTransformer. As part of this, had to make transformer more opinionated
47+
and refuse columns with nulls (raises an error directing to imputers.) `#372 <https://github.com/lvgig/tubular/issues/372>_`
4048
- narwhalified BaseDatetimeTransformer `#375 <https://github.com/azukds/tubular/issues/375>_`
4149
- Optional wanted_values feature has been integrated into the OneHotEncodingTransformer which allows users to specify which levels in a column they wish to encode. `#384 <https://github.com/azukds/tubular/issues/384>_`
4250
- Created unit tests to check if the values provided for wanted_values are as expected and if the output is as expected.
4351
- placeholder
4452
- placeholder
45-
46-
53+
- placeholder
54+
- placeholder
4755

4856
1.4.1 (02/12/2024)
4957
------------------

pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ dynamic = ["version"]
1818
dependencies = [
1919
"pandas>=1.5.0",
2020
"scikit-learn>=1.2.0",
21-
"narwhals >= 1.17.0",
21+
"narwhals >= 1.21.1",
2222
"polars >= 1.9.0",
23+
"beartype >= 0.19.0",
2324
]
2425
requires-python = ">=3.9"
2526
authors = [{ name = "LV GI Data Science Team", email="#[email protected]"}]
@@ -130,6 +131,8 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
130131
"__init__.py" = ["E402", "F401"]
131132
"tests/*" = ["ANN", "S101"]
132133

134+
[tool.ruff.lint.pyupgrade]
135+
keep-runtime-typing=true
133136

134137
[tool.coverage.run]
135138
branch = true

requirements-dev.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#
55
# pip-compile --extra=dev --no-emit-index-url --no-emit-trusted-host --output-file=requirements-dev.txt pyproject.toml
66
#
7+
beartype==0.19.0
8+
# via tubular (pyproject.toml)
79
cfgv==3.4.0
810
# via pre-commit
911
coverage[toml]==7.6.0
@@ -20,7 +22,7 @@ iniconfig==2.0.0
2022
# via pytest
2123
joblib==1.4.2
2224
# via scikit-learn
23-
narwhals==1.17.0
25+
narwhals==1.21.1
2426
# via tubular (pyproject.toml)
2527
nodeenv==1.9.1
2628
# via pre-commit

tests/base_tests.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def test_unexpected_kwarg_error(
453453
),
454454
):
455455
uninitialized_transformers[self.transformer_name](
456-
unexpected_kwarg="spanish inquisition",
456+
unexpected_kwarg=True,
457457
**minimal_attribute_dict[self.transformer_name],
458458
)
459459

tests/conftest.py

+6
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ def minimal_attribute_dict():
127127
},
128128
"BaseMappingTransformer": {
129129
"mappings": {"a": {1: 2, 3: 4}},
130+
# this arg is dependent on inputs, so
131+
# may not generalise well
132+
"return_dtypes": {"a": "Int8"},
130133
},
131134
"BaseMappingTransformMixin": {
132135
"columns": ["a"],
@@ -208,6 +211,9 @@ def minimal_attribute_dict():
208211
},
209212
"MappingTransformer": {
210213
"mappings": {"a": {1: 2, 3: 4}},
214+
# this arg is dependent on inputs, so
215+
# may not generalise well
216+
"return_dtypes": {"a": "Int8"},
211217
},
212218
"MeanImputer": {
213219
"columns": ["b"],

tests/mapping/test_BaseCrossColumnMappingTransformer.py

+29
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,35 @@ def test_adjust_columns_non_string_error(
3131
):
3232
uninitialized_transformers[self.transformer_name](**args)
3333

34+
def test_inferred_return_dtypes(
35+
self,
36+
uninitialized_transformers,
37+
minimal_attribute_dict,
38+
):
39+
"""test that return_dtypes are inferred correctly if not provided - test is
40+
overloaded as these transformers require OrderedDict for multicolumn mapping
41+
"""
42+
43+
kwargs = minimal_attribute_dict[self.transformer_name]
44+
kwargs["mappings"] = {
45+
"a": {"a": 1, "b": 2},
46+
}
47+
kwargs["return_dtypes"] = None
48+
49+
transformer = uninitialized_transformers[self.transformer_name](
50+
**kwargs,
51+
)
52+
53+
expected = {
54+
"a": "Int64",
55+
}
56+
57+
actual = transformer.return_dtypes
58+
59+
assert (
60+
actual == expected
61+
), f"return_dtypes attr not inferred as expected, expected {expected} but got {actual}"
62+
3463

3564
class BaseCrossColumnMappingTransformerTransformTests(
3665
BaseMappingTransformerTransformTests,

tests/mapping/test_BaseCrossColumnNumericTransformer.py

+31
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,37 @@ def test_mapping_values_not_numeric_error(
3636
):
3737
uninitialized_transformers[self.transformer_name](**args)
3838

39+
def test_inferred_return_dtypes(
40+
self,
41+
uninitialized_transformers,
42+
minimal_attribute_dict,
43+
):
44+
"""test that return_dtypes are inferred correctly if not provided - test is
45+
overloaded as these transformers can only handle numeric types
46+
"""
47+
48+
kwargs = minimal_attribute_dict[self.transformer_name]
49+
kwargs["mappings"] = {
50+
"a": {"a": 1, "b": 2},
51+
"c": {"d": 1.0, "e": 2.0},
52+
}
53+
kwargs["return_dtypes"] = None
54+
55+
transformer = uninitialized_transformers[self.transformer_name](
56+
**kwargs,
57+
)
58+
59+
expected = {
60+
"a": "Int64",
61+
"c": "Float64",
62+
}
63+
64+
actual = transformer.return_dtypes
65+
66+
assert (
67+
actual == expected
68+
), f"return_dtypes attr not inferred as expected, expected {expected} but got {actual}"
69+
3970

4071
class BaseCrossColumnNumericTransformerTransformTests(
4172
BaseCrossColumnMappingTransformerTransformTests,

tests/mapping/test_BaseMappingTransformMixin.py

+89-50
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import re
33

44
import pandas as pd
5+
import polars as pl
56
import pytest
6-
import test_aide as ta
77

88
import tests.test_data as d
99
from tests.base_tests import (
@@ -12,7 +12,7 @@
1212
GenericTransformTests,
1313
OtherBaseBehaviourTests,
1414
)
15-
from tests.utils import assert_frame_equal_dispatch
15+
from tests.utils import assert_frame_equal_dispatch, dataframe_init_dispatch
1616
from tubular.mapping import BaseMappingTransformMixin
1717

1818
# Note there are no tests that need inheriting from this file as the only difference is an expected transform output
@@ -21,8 +21,8 @@
2121
@pytest.fixture()
2222
def mapping():
2323
return {
24-
"a": {1: "a", 2: "b", 3: "c", 4: "d", 5: "e", 6: "f"},
25-
"b": {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6},
24+
"a": {1: "a", 2: "b", 3: "c", 4: "d", 5: "e", 6: "f", 7: "g", 8: "h", 9: None},
25+
"b": {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, None: 9},
2626
}
2727

2828

@@ -51,100 +51,133 @@ class TestTransform(GenericTransformTests):
5151
def setup_class(cls):
5252
cls.transformer_name = "BaseMappingTransformMixin"
5353

54-
def test_expected_output(self, mapping):
54+
@pytest.mark.parametrize("library", ["pandas", "polars"])
55+
def test_expected_output(self, mapping, library):
5556
"""Test that X is returned from transform."""
5657

57-
df = d.create_df_1()
58+
df = d.create_df_1(library=library)
5859

59-
expected = pd.DataFrame(
60-
{
61-
"a": ["a", "b", "c", "d", "e", "f"],
62-
"b": [1, 2, 3, 4, 5, 6],
63-
},
60+
expected_dict = {
61+
"a": ["a", "b", "c", "d", "e", "f"],
62+
"b": [1, 2, 3, 4, 5, 6],
63+
}
64+
65+
expected = dataframe_init_dispatch(
66+
dataframe_dict=expected_dict,
67+
library=library,
6468
)
6569

66-
x = BaseMappingTransformMixin(columns=["a", "b"])
70+
transformer = BaseMappingTransformMixin(columns=["a", "b"])
6771

68-
x.mappings = mapping
72+
# if transformer is not yet polars compatible, skip this test
73+
if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
74+
return
6975

70-
df_transformed = x.transform(df)
76+
transformer.mappings = mapping
77+
transformer.return_dtypes = {"a": "String", "b": "Int64"}
7178

72-
ta.equality.assert_equal_dispatch(
73-
expected=expected,
74-
actual=df_transformed,
75-
msg="BaseMappingTransformMixin from transform",
76-
)
79+
df_transformed = transformer.transform(df)
80+
81+
assert_frame_equal_dispatch(expected, df_transformed)
7782

78-
def test_mappings_unchanged(self, mapping):
83+
@pytest.mark.parametrize("library", ["pandas", "polars"])
84+
def test_mappings_unchanged(self, mapping, library):
7985
"""Test that mappings is unchanged in transform."""
80-
df = d.create_df_1()
86+
df = d.create_df_1(library=library)
8187

82-
x = BaseMappingTransformMixin(columns=["a", "b"])
88+
transformer = BaseMappingTransformMixin(columns=["a", "b"])
8389

84-
x.mappings = mapping
90+
# if transformer is not yet polars compatible, skip this test
91+
if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
92+
return
8593

86-
x.transform(df)
94+
transformer.mappings = mapping
95+
transformer.return_dtypes = {
96+
"a": "String",
97+
"b": "Int64",
98+
}
8799

88-
ta.equality.assert_equal_dispatch(
89-
expected=mapping,
90-
actual=x.mappings,
91-
msg="BaseMappingTransformer.transform has changed self.mappings unexpectedly",
92-
)
100+
transformer.transform(df)
101+
102+
assert (
103+
mapping == transformer.mappings
104+
), f"BaseMappingTransformer.transform has changed self.mappings unexpectedly, expected {mapping} but got {transformer.mappings}"
93105

106+
@pytest.mark.parametrize("library", ["pandas", "polars"])
94107
@pytest.mark.parametrize("non_df", [1, True, "a", [1, 2], {"a": 1}, None])
95108
def test_non_pd_type_error(
96109
self,
97110
non_df,
98111
mapping,
112+
library,
99113
):
100114
"""Test that an error is raised in transform is X is not a pd.DataFrame."""
101115

102-
df = d.create_df_10()
116+
df = d.create_df_10(library=library)
103117

104-
x = BaseMappingTransformMixin(columns=["a"])
118+
transformer = BaseMappingTransformMixin(columns=["a"])
105119

106-
x.mappings = mapping
120+
# if transformer is not yet polars compatible, skip this test
121+
if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
122+
return
107123

108-
x_fitted = x.fit(df, df["c"])
124+
transformer.mappings = mapping
125+
transformer.return_dtypes = {
126+
"a": "String",
127+
}
128+
129+
x_fitted = transformer.fit(df, df["c"])
109130

110131
with pytest.raises(
111132
TypeError,
112133
match="BaseMappingTransformMixin: X should be a polars or pandas DataFrame/LazyFrame",
113134
):
114135
x_fitted.transform(X=non_df)
115136

116-
def test_no_rows_error(self, mapping):
137+
@pytest.mark.parametrize("library", ["pandas", "polars"])
138+
def test_no_rows_error(self, mapping, library):
117139
"""Test an error is raised if X has no rows."""
118-
df = d.create_df_10()
140+
df = d.create_df_10(library=library)
141+
142+
transformer = BaseMappingTransformMixin(columns=["a"])
119143

120-
x = BaseMappingTransformMixin(columns=["a"])
144+
# if transformer is not yet polars compatible, skip this test
145+
if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
146+
return
121147

122-
x.mappings = mapping
148+
transformer.mappings = mapping
149+
transformer.return_dtypes = {"a": "String"}
123150

124-
x = x.fit(df, df["c"])
151+
transformer = transformer.fit(df, df["c"])
125152

126153
df = pd.DataFrame(columns=["a", "b", "c"])
127154

128155
with pytest.raises(
129156
ValueError,
130157
match=re.escape("BaseMappingTransformMixin: X has no rows; (0, 3)"),
131158
):
132-
x.transform(df)
159+
transformer.transform(df)
133160

134-
def test_original_df_not_updated(self, mapping):
161+
@pytest.mark.parametrize("library", ["pandas", "polars"])
162+
def test_original_df_not_updated(self, mapping, library):
135163
"""Test that the original dataframe is not transformed when transform method used."""
136164

137-
df = d.create_df_10()
165+
df = d.create_df_10(library=library)
166+
167+
transformer = BaseMappingTransformMixin(columns=["a"])
138168

139-
x = BaseMappingTransformMixin(columns=["a"])
169+
# if transformer is not yet polars compatible, skip this test
170+
if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
171+
return
140172

141-
x.mappings = mapping
173+
transformer.mappings = mapping
174+
transformer.return_dtypes = {"a": "String", "b": "Int64"}
142175

143-
x = x.fit(df, df["c"])
176+
transformer = transformer.fit(df, df["c"])
144177

145-
_ = x.transform(df)
178+
_ = transformer.transform(df)
146179

147-
pd.testing.assert_frame_equal(df, d.create_df_10())
180+
assert_frame_equal_dispatch(df, d.create_df_10(library=library))
148181

149182
@pytest.mark.parametrize(
150183
"minimal_dataframe_lookup",
@@ -160,17 +193,23 @@ def test_pandas_index_not_updated(
160193
"""Test that the original (pandas) dataframe index is not transformed when transform method used."""
161194

162195
df = minimal_dataframe_lookup[self.transformer_name]
163-
x = initialized_transformers[self.transformer_name]
164-
x.mappings = mapping
196+
transformer = initialized_transformers[self.transformer_name]
197+
198+
# if transformer is not yet polars compatible, skip this test
199+
if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
200+
return
201+
202+
transformer.mappings = mapping
203+
transformer.return_dtypes = {"a": "String", "b": "String"}
165204

166205
# update to abnormal index
167206
df.index = [2 * i for i in df.index]
168207

169208
original_df = copy.deepcopy(df)
170209

171-
x = x.fit(df, df["a"])
210+
transformer = transformer.fit(df, df["a"])
172211

173-
_ = x.transform(df)
212+
_ = transformer.transform(df)
174213

175214
assert_frame_equal_dispatch(df, original_df)
176215

0 commit comments

Comments
 (0)