Skip to content

Commit eed7eb1

Browse files
committed
No warning test case added
1 parent 27e1094 commit eed7eb1

File tree

3 files changed

+78
-24
lines changed

3 files changed

+78
-24
lines changed

CHANGELOG.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,12 @@ Changed
3838
- narwhalified DropOriginalMixin `#352 <https://github.com/lvgig/tubular/issues/352>`_
3939
- narwhalified BaseMappingTransformer `#367 <https://github.com/lvgig/tubular/issues/367>`_
4040
- narwhalified BaseDatetimeTransformer `#375 <https://github.com/azukds/tubular/issues/375>`_
41+
- Optional wanted_values feature has been integrated into the OneHotEncodingTransformer which allows users to specify which levels in a column they wish to encode. `#384 <https://github.com/azukds/tubular/issues/384>`_
42+
- Created unit tests to check if the values provided for wanted_values are as expected and if the output is as expected.
4143
- placeholder
4244
- placeholder
43-
- placeholder
45+
46+
4447

4548
1.4.1 (02/12/2024)
4649
------------------

tests/nominal/test_OneHotEncodingTransformer.py

+57-18
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_wanted_values_is_dict(self, values, minimal_attribute_dict):
3838

3939
with pytest.raises(
4040
TypeError,
41-
match="OneHotEncodingTransformer: Wanted_values should be a dictionary",
41+
match="OneHotEncodingTransformer: wanted_values should be a dictionary",
4242
):
4343
OneHotEncodingTransformer(**args)
4444

@@ -131,7 +131,8 @@ def test_fit_missing_levels_warning(self, library):
131131
df = d.create_df_1(library=library)
132132

133133
transformer = OneHotEncodingTransformer(
134-
columns=["b"], wanted_values={"b": ["f", "g"]}
134+
columns=["b"],
135+
wanted_values={"b": ["f", "g"]},
135136
)
136137

137138
with pytest.warns(
@@ -160,6 +161,21 @@ def test_fields_with_over_100_levels_error(self, library):
160161
):
161162
transformer.fit(df)
162163

164+
@pytest.mark.parametrize(
165+
"library",
166+
["pandas", "polars"],
167+
)
168+
def test_fit_no_warning_if_all_wanted_values_present(self, library):
169+
"""Test that OneHotEncodingTransformer.fit does NOT raise a warning when all levels in wanted_values are present in the data."""
170+
df = d.create_df_1(library=library)
171+
172+
transformer = OneHotEncodingTransformer(
173+
columns=["b"], wanted_values={"b": ["a", "b", "c", "d", "e", "f"]}
174+
)
175+
176+
with pytest.warns(None):
177+
transformer.fit(df)
178+
163179

164180
class TestTransform(
165181
DropOriginalTransformMixinTests,
@@ -378,7 +394,8 @@ def test_transform_missing_levels_warning(self, library):
378394
df_test = d.create_df_8(library=library)
379395

380396
transformer = OneHotEncodingTransformer(
381-
columns=["b"], wanted_values={"b": ["v", "x", "z"]}
397+
columns=["b"],
398+
wanted_values={"b": ["v", "x", "z"]},
382399
)
383400

384401
transformer.fit(df_train)
@@ -433,42 +450,64 @@ def test_unseen_categories_encoded_as_all_zeroes(self, library):
433450
df_expected_row,
434451
)
435452

436-
437453
@pytest.mark.parametrize(
438454
"library",
439455
["pandas", "polars"],
440456
)
441-
def test_transform_missing_levels_encoded_as_all_zeroes(self, library):
442-
"""Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
457+
def test_transform_output_with_wanted_values_arg(self, library):
458+
"""
459+
Test that OneHotEncodingTransformer.transform zero-fills levels listed in the user-specified "wanted_values" that are absent from the data, and encodes only the levels listed in "wanted_values".
460+
461+
"""
443462
df_train = d.create_df_7(library=library)
444463
df_test = d.create_df_8(library=library)
445464

446465
transformer = OneHotEncodingTransformer(
447-
columns=["b"], wanted_values={"b": ["v", "x", "z"]}
466+
columns=["b"],
467+
wanted_values={"b": ["v", "x", "z"]},
448468
)
449469

450470
transformer.fit(df_train)
451471
df_transformed = transformer.transform(df_test)
452472

453-
expected_df_dict= {
473+
expected_df_dict = {
454474
"a": [1, 5, 2, 3, 3],
455475
"b": ["w", "w", "z", "y", "x"],
456476
"c": ["a", "a", "c", "b", "a"],
457-
"b_v": [0]*5,
458-
"b_x": [0,0,0,0,1],
459-
"b_z":[0,0,1,0,0],
477+
"b_v": [0] * 5,
478+
"b_x": [0, 0, 0, 0, 1],
479+
"b_z": [0, 0, 1, 0, 0],
460480
}
461-
expected_df = dataframe_init_dispatch(library=library, dataframe_dict=expected_df_dict)
481+
expected_df = dataframe_init_dispatch(
482+
library=library,
483+
dataframe_dict=expected_df_dict,
484+
)
462485
expected_df = nw.from_native(expected_df)
463486
# cast the columns
464-
boolean_cols= ["b_v", "b_x", "b_z"]
487+
boolean_cols = ["b_v", "b_x", "b_z"]
465488
for col_name in boolean_cols:
466-
expected_df= expected_df.with_columns(
467-
nw.col(col_name).cast(nw.Boolean)
489+
expected_df = expected_df.with_columns(
490+
nw.col(col_name).cast(nw.Boolean),
468491
)
469-
expected_df= expected_df.with_columns(
470-
nw.col("c").cast(nw.Categorical)
492+
expected_df = expected_df.with_columns(
493+
nw.col("c").cast(nw.Categorical),
471494
)
472495

473496
assert_frame_equal_dispatch(df_transformed, expected_df.to_native())
474-
497+
498+
@pytest.mark.parametrize(
499+
"library",
500+
["pandas", "polars"],
501+
)
502+
def test_transform_no_warning_if_all_wanted_values_present(self, library):
503+
"""Test that OneHotEncodingTransformer.transform does NOT raise a warning when all levels in wanted_values are present in the data."""
504+
df_train = d.create_df_7(library=library)
505+
df_test = d.create_df_8(library=library)
506+
507+
transformer = OneHotEncodingTransformer(
508+
columns=["b"], wanted_values={"b": ["x", "z", "y"]}
509+
)
510+
transformer.fit(df_train)
511+
512+
with pytest.warns(None):
513+
transformer.transform(df_test)

tubular/nominal.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ def __init__(
11431143

11441144
if wanted_values is not None:
11451145
if not isinstance(wanted_values, dict):
1146-
msg = f"{self.classname()}: Wanted_values should be a dictionary"
1146+
msg = f"{self.classname()}: wanted_values should be a dictionary"
11471147
raise TypeError(msg)
11481148

11491149
for key, val_list in wanted_values.items():
@@ -1225,7 +1225,7 @@ def fit(self, X: FrameT, y: nw.Series | None = None) -> FrameT:
12251225
self.categories_[c] = final_categories
12261226
self.new_feature_names_[c] = self._get_feature_names(column=c)
12271227

1228-
present_levels = set(X.select(nw.col(c).unique()).get_column(c).to_list())
1228+
present_levels = set(X.get_column(c).unique().to_list())
12291229
missing_levels = self._warn_missing_levels(
12301230
present_levels,
12311231
c,
@@ -1239,12 +1239,24 @@ def _warn_missing_levels(
12391239
present_levels: list,
12401240
c: str,
12411241
missing_levels: dict[str, list[str]],
1242-
) -> list:
1242+
) -> dict[str, list[str]]:
1243+
"""Logs a warning for user-specified levels that are not found in the dataset and updates "missing_levels[c]" with those missing levels.
1244+
1245+
Parameters
1246+
----------
1247+
present_levels: list
1248+
List of levels observed in the data.
1249+
c: str
1250+
The column name being checked for missing user-specified levels.
1251+
missing_levels: dict[str, list[str]]
1252+
Dictionary containing missing user-specified levels for each column.
1253+
1254+
"""
12431255
# print warning for missing levels
12441256
missing_levels[c] = list(
12451257
set(self.categories_[c]).difference(present_levels),
12461258
)
1247-
if len(missing_levels) > 0:
1259+
if len(missing_levels[c]) > 0:
12481260
warning_msg = f"{self.classname()}: column {c} includes user-specified values {missing_levels[c]} not found in the dataset"
12491261
warnings.warn(warning_msg, UserWarning, stacklevel=2)
12501262

@@ -1300,7 +1312,7 @@ def transform(self, X: FrameT) -> FrameT:
13001312
)
13011313

13021314
# print warning for unseen levels
1303-
present_levels = set(X.select(nw.col(c).unique()).get_column(c).to_list())
1315+
present_levels = set(X.get_column(c).unique().to_list())
13041316
unseen_levels = present_levels.difference(set(self.categories_[c]))
13051317
if len(unseen_levels) > 0:
13061318
warning_msg = f"{self.classname()}: column {c} has unseen categories: {unseen_levels}"

0 commit comments

Comments
 (0)