Commit 10f8f0c

Init testing
1 parent fb0c6f5 commit 10f8f0c

2 files changed (+60 -27 lines)

2 files changed

+60
-27
lines changed

tests/nominal/test_OneHotEncodingTransformer.py

+56 -25
@@ -26,51 +26,48 @@ class TestInit(
     def setup_class(cls):
         cls.transformer_name = "OneHotEncodingTransformer"
 
-
     # Tests for wanted_values parameter
 
     @pytest.mark.parametrize(
-        "values",
-        [ "a", ["a", "b"], 123, True],
+        "values",
+        ["a", ["a", "b"], 123, True],
     )
     def test_wanted_values_is_dict(self, values, minimal_attribute_dict):
         args = minimal_attribute_dict[self.transformer_name]
-        args["wanted_values"]=values
-
+        args["wanted_values"] = values
+
         with pytest.raises(
             TypeError,
-            match= "OneHotEncodingTransformer: Wanted_values should be a dictionary",
+            match="OneHotEncodingTransformer: Wanted_values should be a dictionary",
         ):
             OneHotEncodingTransformer(**args)
 
-
     @pytest.mark.parametrize(
-        "values",
-        [
-            {1:["a", "b"]},
-            {True:["a"]},
-            {("a",):["b", "c"]},
-        ]
+        "values",
+        [
+            {1: ["a", "b"]},
+            {True: ["a"]},
+            {("a",): ["b", "c"]},
+        ],
     )
     def test_wanted_values_key_is_str(self, values, minimal_attribute_dict):
         args = minimal_attribute_dict[self.transformer_name]
-        args["wanted_values"]= values
-
+        args["wanted_values"] = values
+
         with pytest.raises(
             TypeError,
-            match= "OneHotEncodingTransformer: Key in 'wanted_values' should be a string",
+            match="OneHotEncodingTransformer: Key in 'wanted_values' should be a string",
         ):
             OneHotEncodingTransformer(**args)
 
-
     @pytest.mark.parametrize(
         "values",
         [
             {"a": "b"},
-            {"a":("a","b")},
+            {"a": ("a", "b")},
             {"a": True},
             {"a": 123},
-        ]
+        ],
    )
     def test_wanted_values_value_is_list(self, values, minimal_attribute_dict):
         args = minimal_attribute_dict[self.transformer_name]
@@ -82,26 +79,24 @@ def test_wanted_values_value_is_list(self, values, minimal_attribute_dict):
         ):
             OneHotEncodingTransformer(**args)
 
-
     @pytest.mark.parametrize(
         "values",
         [
             {"a": ["b", 123]},
             {"a": ["b", True]},
             {"a": ["b", None]},
             {"a": ["b", ["a", "b"]]},
-        ]
+        ],
     )
     def test_wanted_values_entries_are_str(self, values, minimal_attribute_dict):
-        args= minimal_attribute_dict[self.transformer_name]
-        args["wanted_values"]= values
+        args = minimal_attribute_dict[self.transformer_name]
+        args["wanted_values"] = values
 
         with pytest.raises(
             TypeError,
-            match= "OneHotEncodingTransformer: Entries in 'wanted_values' list should be a string"
+            match="OneHotEncodingTransformer: Entries in 'wanted_values' list should be a string",
         ):
             OneHotEncodingTransformer(**args)
-
 
 
 class TestFit(GenericFitTests):
@@ -127,6 +122,23 @@ def test_nulls_in_X_error(self, library):
         ):
             transformer.fit(df)
 
+    @pytest.mark.parametrize(
+        "library",
+        ["pandas", "polars"]
+    )
+    def test_fit_missing_levels_warning(self, library):
+        """ Test OneHotEncodingTransformer.fit triggers a warning for missing levels."""
+        df = d.create_df_1(library=library)
+
+        transformer = OneHotEncodingTransformer(columns=["b"], wanted_values={"b": ["f", "g"]})
+
+        with pytest.warns(
+            UserWarning,
+            match= ("OneHotEncodingTransformer: column b includes user-specified values .* not found in the dataset"),
+        ):
+            transformer.fit(df)
+
+
     @pytest.mark.parametrize(
         "library",
         ["pandas", "polars"],
@@ -353,6 +365,25 @@ def test_warning_generated_by_unseen_categories(self, library):
         with pytest.warns(UserWarning, match="unseen categories"):
            transformer.transform(df_test)
 
+    @pytest.mark.parametrize(
+        "library",
+        ["pandas", "polars"]
+    )
+    def test_transform_missing_levels_warning(self, library):
+        """ Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
+        df_train = d.create_df_7(library=library)
+        df_test = d.create_df_8(library=library)
+
+        transformer = OneHotEncodingTransformer(columns=["b"], wanted_values={"b": ["v", "x", "z"]})
+
+        transformer.fit(df_train)
+
+        with pytest.warns(
+            UserWarning,
+            match="OneHotEncodingTransformer: column b includes user-specified values .* not found in the dataset"
+        ):
+            transformer.transform(df_test)
+
     @pytest.mark.parametrize(
         "library",
         ["pandas", "polars"],

tubular/nominal.py

+4 -2
@@ -1227,7 +1227,9 @@ def fit(self, X: FrameT, y: nw.Series | None = None) -> FrameT:
 
             present_levels = set(X.select(nw.col(c).unique()).get_column(c).to_list())
             missing_levels = self._warn_missing_levels(
-                present_levels, c, missing_levels
+                present_levels,
+                c,
+                missing_levels,
             )
 
         return self
@@ -1242,7 +1244,7 @@ def _warn_missing_levels(
         missing_levels[c] = list(
             set(self.categories_[c]).difference(present_levels),
         )
-        if missing_levels:
+        if len(missing_levels) > 0:
            warning_msg = f"{self.classname()}: column {c} includes user-specified values {missing_levels[c]} not found in the dataset"
            warnings.warn(warning_msg, UserWarning, stacklevel=2)
 
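
The second hunk above swaps the truthiness check on missing_levels for an explicit length check and reflows the _warn_missing_levels call. As a rough standalone sketch of that check (a simplified illustration, not the library method itself): the user-specified levels are compared against the levels present in the column via a set difference, and a UserWarning is emitted whenever anything is missing.

import warnings


def warn_missing_levels(wanted: list[str], present: set[str], column: str) -> list[str]:
    """Return the wanted levels absent from the data, warning if there are any."""
    missing = list(set(wanted).difference(present))
    if len(missing) > 0:
        warnings.warn(
            f"OneHotEncodingTransformer: column {column} includes "
            f"user-specified values {missing} not found in the dataset",
            UserWarning,
            stacklevel=2,
        )
    return missing


# Example: "z" was requested but only "v" and "x" occur in the column.
warn_missing_levels(["v", "x", "z"], {"v", "x"}, "b")  # warns about ['z']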
