Skip to content

Commit 9308197

Browse files
authored
Merge pull request #388 from azukds/feature/ohe_values_2
Optional wanted_values feature added to OHE
2 parents 6fc1124 + abe7d63 commit 9308197

File tree

3 files changed

+283
-10
lines changed

3 files changed

+283
-10
lines changed

CHANGELOG.rst

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ functionality, as this is more complicated when transform is opinionated on type
4646
- narwhalified GroupRareLevelsTransformer. As part of this, had to make transformer more opinionated
4747
and refuse columns with nulls (raises an error directing to imputers.) `#372 <https://github.com/lvgig/tubular/issues/372>_`
4848
- narwhalified BaseDatetimeTransformer `#375 <https://github.com/azukds/tubular/issues/375>`
49+
- Optional wanted_levels feature has been integrated into the OneHotEncodingTransformer which allows users to specify which levels in a column they wish to encode. `#384 <https://github.com/azukds/tubular/issues/384>_`
50+
- Created unit tests to check if the values provided for wanted_values are as expected and if the output is as expected.
4951
- placeholder
5052
- placeholder
5153
- placeholder

tests/nominal/test_OneHotEncodingTransformer.py

+198
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,78 @@ class TestInit(
2626
def setup_class(cls):
2727
cls.transformer_name = "OneHotEncodingTransformer"
2828

29+
# Tests for wanted_values parameter
30+
31+
@pytest.mark.parametrize(
32+
"values",
33+
["a", ["a", "b"], 123, True],
34+
)
35+
def test_wanted_values_is_dict(self, values, minimal_attribute_dict):
36+
args = minimal_attribute_dict[self.transformer_name]
37+
args["wanted_values"] = values
38+
39+
with pytest.raises(
40+
TypeError,
41+
match="OneHotEncodingTransformer: wanted_values should be a dictionary",
42+
):
43+
OneHotEncodingTransformer(**args)
44+
45+
@pytest.mark.parametrize(
46+
"values",
47+
[
48+
{1: ["a", "b"]},
49+
{True: ["a"]},
50+
{("a",): ["b", "c"]},
51+
],
52+
)
53+
def test_wanted_values_key_is_str(self, values, minimal_attribute_dict):
54+
args = minimal_attribute_dict[self.transformer_name]
55+
args["wanted_values"] = values
56+
57+
with pytest.raises(
58+
TypeError,
59+
match="OneHotEncodingTransformer: Key in 'wanted_values' should be a string",
60+
):
61+
OneHotEncodingTransformer(**args)
62+
63+
@pytest.mark.parametrize(
64+
"values",
65+
[
66+
{"a": "b"},
67+
{"a": ("a", "b")},
68+
{"a": True},
69+
{"a": 123},
70+
],
71+
)
72+
def test_wanted_values_value_is_list(self, values, minimal_attribute_dict):
73+
args = minimal_attribute_dict[self.transformer_name]
74+
args["wanted_values"] = values
75+
76+
with pytest.raises(
77+
TypeError,
78+
match="OneHotEncodingTransformer: Values in the 'wanted_values' dictionary should be a list",
79+
):
80+
OneHotEncodingTransformer(**args)
81+
82+
@pytest.mark.parametrize(
83+
"values",
84+
[
85+
{"a": ["b", 123]},
86+
{"a": ["b", True]},
87+
{"a": ["b", None]},
88+
{"a": ["b", ["a", "b"]]},
89+
],
90+
)
91+
def test_wanted_values_entries_are_str(self, values, minimal_attribute_dict):
92+
args = minimal_attribute_dict[self.transformer_name]
93+
args["wanted_values"] = values
94+
95+
with pytest.raises(
96+
TypeError,
97+
match="OneHotEncodingTransformer: Entries in 'wanted_values' list should be a string",
98+
):
99+
OneHotEncodingTransformer(**args)
100+
29101

30102
class TestFit(GenericFitTests):
31103
"""Generic tests for transformer.fit()"""
@@ -50,6 +122,27 @@ def test_nulls_in_X_error(self, library):
50122
):
51123
transformer.fit(df)
52124

125+
@pytest.mark.parametrize(
126+
"library",
127+
["pandas", "polars"],
128+
)
129+
def test_fit_missing_levels_warning(self, library):
130+
"""Test OneHotEncodingTransformer.fit triggers a warning for missing levels."""
131+
df = d.create_df_1(library=library)
132+
133+
transformer = OneHotEncodingTransformer(
134+
columns=["b"],
135+
wanted_values={"b": ["f", "g"]},
136+
)
137+
138+
with pytest.warns(
139+
UserWarning,
140+
match=(
141+
r"OneHotEncodingTransformer: column b includes user-specified values \['g'\] not found in the dataset"
142+
),
143+
):
144+
transformer.fit(df)
145+
53146
@pytest.mark.parametrize(
54147
"library",
55148
["pandas", "polars"],
@@ -68,6 +161,24 @@ def test_fields_with_over_100_levels_error(self, library):
68161
):
69162
transformer.fit(df)
70163

164+
@pytest.mark.parametrize(
165+
"library",
166+
["pandas", "polars"],
167+
)
168+
def test_fit_no_warning_if_all_wanted_values_present(self, library, recwarn):
169+
"""Test that OneHotEncodingTransformer.fit does NOT raise a warning when all levels in wanted_levels are present in the data."""
170+
df = d.create_df_1(library=library)
171+
172+
transformer = OneHotEncodingTransformer(
173+
columns=["b"],
174+
wanted_values={"b": ["a", "b", "c", "d", "e", "f"]},
175+
)
176+
177+
transformer.fit(df)
178+
assert (
179+
len(recwarn) == 0
180+
), "OneHotEncodingTransformer.fit is raising unexpected warnings"
181+
71182

72183
class TestTransform(
73184
DropOriginalTransformMixinTests,
@@ -276,6 +387,28 @@ def test_warning_generated_by_unseen_categories(self, library):
276387
with pytest.warns(UserWarning, match="unseen categories"):
277388
transformer.transform(df_test)
278389

390+
@pytest.mark.parametrize(
391+
"library",
392+
["pandas", "polars"],
393+
)
394+
def test_transform_missing_levels_warning(self, library):
395+
"""Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
396+
df_train = d.create_df_7(library=library)
397+
df_test = d.create_df_8(library=library)
398+
399+
transformer = OneHotEncodingTransformer(
400+
columns=["b"],
401+
wanted_values={"b": ["v", "x", "z"]},
402+
)
403+
404+
transformer.fit(df_train)
405+
406+
with pytest.warns(
407+
UserWarning,
408+
match=r"OneHotEncodingTransformer: column b includes user-specified values \['v'\] not found in the dataset",
409+
):
410+
transformer.transform(df_test)
411+
279412
@pytest.mark.parametrize(
280413
"library",
281414
["pandas", "polars"],
@@ -319,3 +452,68 @@ def test_unseen_categories_encoded_as_all_zeroes(self, library):
319452
df_transformed_row[column_order],
320453
df_expected_row,
321454
)
455+
456+
@pytest.mark.parametrize(
457+
"library",
458+
["pandas", "polars"],
459+
)
460+
def test_transform_output_with_wanted_values_arg(self, library):
461+
"""
462+
Test to verify OneHotEncodingTransformer.transform zero-filled levels from user-specified "wanted_levels" and encodes only those listed in "wanted_levels".
463+
464+
"""
465+
df_train = d.create_df_7(library=library)
466+
df_test = d.create_df_8(library=library)
467+
468+
transformer = OneHotEncodingTransformer(
469+
columns=["b"],
470+
wanted_values={"b": ["v", "x", "z"]},
471+
)
472+
473+
transformer.fit(df_train)
474+
df_transformed = transformer.transform(df_test)
475+
476+
expected_df_dict = {
477+
"a": [1, 5, 2, 3, 3],
478+
"b": ["w", "w", "z", "y", "x"],
479+
"c": ["a", "a", "c", "b", "a"],
480+
"b_v": [0] * 5,
481+
"b_x": [0, 0, 0, 0, 1],
482+
"b_z": [0, 0, 1, 0, 0],
483+
}
484+
expected_df = dataframe_init_dispatch(
485+
library=library,
486+
dataframe_dict=expected_df_dict,
487+
)
488+
expected_df = nw.from_native(expected_df)
489+
# cast the columns
490+
boolean_cols = ["b_v", "b_x", "b_z"]
491+
for col_name in boolean_cols:
492+
expected_df = expected_df.with_columns(
493+
nw.col(col_name).cast(nw.Boolean),
494+
)
495+
expected_df = expected_df.with_columns(
496+
nw.col("c").cast(nw.Categorical),
497+
)
498+
499+
assert_frame_equal_dispatch(df_transformed, expected_df.to_native())
500+
501+
@pytest.mark.parametrize(
502+
"library",
503+
["pandas", "polars"],
504+
)
505+
def test_transform_no_warning_if_all_wanted_values_present(self, library, recwarn):
506+
"""Test that OneHotEncodingTransformer.transform does NOT raise a warning when all levels in wanted_levels are present in the data."""
507+
df_train = d.create_df_8(library=library)
508+
df_test = d.create_df_7(library=library)
509+
510+
transformer = OneHotEncodingTransformer(
511+
columns=["b"],
512+
wanted_values={"b": ["z", "y", "x"]},
513+
)
514+
transformer.fit(df_train)
515+
transformer.transform(df_test)
516+
517+
assert (
518+
len(recwarn) == 0
519+
), "OneHotEncodingTransformer.transform is raising unexpected warnings"

tubular/nominal.py

+83-10
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,9 @@ class OneHotEncodingTransformer(
11341134
Names of columns to transform. If the default of None is supplied all object and category
11351135
columns in X are used.
11361136
1137+
wanted_values: dict[str, list[str] or None , default = None
1138+
Optional parameter to select specific column levels to be transformed. If it is None, all levels in the categorical column will be encoded. It will take the format {col1: [level_1, level_2, ...]}.
1139+
11371140
separator : str
11381141
Used to create dummy column names, the name will take
11391142
the format [categorical feature][separator][category level]
@@ -1170,6 +1173,7 @@ class attribute, indicates whether transformer has been converted to polars/pand
11701173
def __init__(
11711174
self,
11721175
columns: str | list[str] | None = None,
1176+
wanted_values: dict[str, list[str]] | None = None,
11731177
separator: str = "_",
11741178
drop_original: bool = False,
11751179
copy: bool | None = None,
@@ -1184,6 +1188,29 @@ def __init__(
11841188
**kwargs,
11851189
)
11861190

1191+
if wanted_values is not None:
1192+
if not isinstance(wanted_values, dict):
1193+
msg = f"{self.classname()}: wanted_values should be a dictionary"
1194+
raise TypeError(msg)
1195+
1196+
for key, val_list in wanted_values.items():
1197+
# check key is a string
1198+
if not isinstance(key, str):
1199+
msg = f"{self.classname()}: Key in 'wanted_values' should be a string"
1200+
raise TypeError(msg)
1201+
1202+
# check value is a list
1203+
if not isinstance(val_list, list):
1204+
msg = f"{self.classname()}: Values in the 'wanted_values' dictionary should be a list"
1205+
raise TypeError(msg)
1206+
1207+
# check if each value within the list is a string
1208+
for val in val_list:
1209+
if not isinstance(val, str):
1210+
msg = f"{self.classname()}: Entries in 'wanted_values' list should be a string"
1211+
raise TypeError(msg)
1212+
1213+
self.wanted_values = wanted_values
11871214
self.set_drop_original_column(drop_original)
11881215
self.check_and_set_separator_column(separator)
11891216

@@ -1214,6 +1241,7 @@ def fit(self, X: FrameT, y: nw.Series | None = None) -> FrameT:
12141241
self.categories_ = {}
12151242
self.new_feature_names_ = {}
12161243
# Check each field has less than 100 categories/levels
1244+
missing_levels = {}
12171245
for c in self.columns:
12181246
levels = X.select(nw.col(c).unique())
12191247

@@ -1231,12 +1259,60 @@ def fit(self, X: FrameT, y: nw.Series | None = None) -> FrameT:
12311259
# for consistency
12321260
levels_list.sort()
12331261

1234-
self.categories_[c] = levels_list
1262+
# categories if 'values' is provided
1263+
selected_values = (
1264+
self.wanted_values.get(c, None) if self.wanted_values else None
1265+
)
1266+
1267+
if selected_values is None:
1268+
final_categories = levels_list
1269+
else:
1270+
final_categories = selected_values
12351271

1272+
self.categories_[c] = final_categories
12361273
self.new_feature_names_[c] = self._get_feature_names(column=c)
12371274

1275+
present_levels = set(X.get_column(c).unique().to_list())
1276+
missing_levels = self._warn_missing_levels(
1277+
present_levels,
1278+
c,
1279+
missing_levels,
1280+
)
1281+
12381282
return self
12391283

1284+
def _warn_missing_levels(
1285+
self,
1286+
present_levels: list,
1287+
c: str,
1288+
missing_levels: dict[str, list[str]],
1289+
) -> dict[str, list[str]]:
1290+
"""Logs a warning for user-specifed levels that are not found in the dataset and updates "missing_levels[c]" with those missing levels.
1291+
1292+
Parameters
1293+
----------
1294+
present_levels: list
1295+
List of levels observed in the data.
1296+
c: str
1297+
The column name being checked for missing user-specified levels.
1298+
missing_levels: dict[str, list[str]]
1299+
Dictionary containing missing user-specified levels for each column.
1300+
Returns
1301+
-------
1302+
missing_levels : dict[str, list[str]]
1303+
Dictionary updated to reflect new missing levels for column c
1304+
1305+
"""
1306+
# print warning for missing levels
1307+
missing_levels[c] = list(
1308+
set(self.categories_[c]).difference(present_levels),
1309+
)
1310+
if len(missing_levels[c]) > 0:
1311+
warning_msg = f"{self.classname()}: column {c} includes user-specified values {missing_levels[c]} not found in the dataset"
1312+
warnings.warn(warning_msg, UserWarning, stacklevel=2)
1313+
1314+
return missing_levels
1315+
12401316
def _get_feature_names(
12411317
self,
12421318
column: str,
@@ -1287,17 +1363,14 @@ def transform(self, X: FrameT) -> FrameT:
12871363
)
12881364

12891365
# print warning for unseen levels
1290-
present_levels = set(X.select(nw.col(c).unique()).get_column(c).to_list())
1366+
present_levels = set(X.get_column(c).unique().to_list())
12911367
unseen_levels = present_levels.difference(set(self.categories_[c]))
1292-
missing_levels[c] = list(
1293-
set(self.categories_[c]).difference(present_levels),
1294-
)
12951368
if len(unseen_levels) > 0:
1296-
warnings.warn(
1297-
f"{self.classname()}: column {c} has unseen categories: {unseen_levels}",
1298-
UserWarning,
1299-
stacklevel=2,
1300-
)
1369+
warning_msg = f"{self.classname()}: column {c} has unseen categories: {unseen_levels}"
1370+
warnings.warn(warning_msg, UserWarning, stacklevel=2)
1371+
1372+
# print warning for missing levels
1373+
self._warn_missing_levels(present_levels, c, missing_levels)
13011374

13021375
dummies = X.get_column(c).to_dummies(separator=self.separator)
13031376

0 commit comments

Comments
 (0)