Skip to content

Commit 27e1094

Browse files
committed
Transform tests
1 parent 10f8f0c commit 27e1094

File tree

1 file changed

+58
-13
lines changed

1 file changed

+58
-13
lines changed

tests/nominal/test_OneHotEncodingTransformer.py

+58-13
Original file line numberDiff line numberDiff line change
@@ -123,22 +123,25 @@ def test_nulls_in_X_error(self, library):
123123
transformer.fit(df)
124124

125125
@pytest.mark.parametrize(
126-
"library",
127-
["pandas", "polars"]
126+
"library",
127+
["pandas", "polars"],
128128
)
129129
def test_fit_missing_levels_warning(self, library):
130-
""" Test OneHotEncodingTransformer.fit triggers a warning for missing levels."""
130+
"""Test OneHotEncodingTransformer.fit triggers a warning for missing levels."""
131131
df = d.create_df_1(library=library)
132132

133-
transformer = OneHotEncodingTransformer(columns=["b"], wanted_values={"b": ["f", "g"]})
133+
transformer = OneHotEncodingTransformer(
134+
columns=["b"], wanted_values={"b": ["f", "g"]}
135+
)
134136

135137
with pytest.warns(
136138
UserWarning,
137-
match= ("OneHotEncodingTransformer: column b includes user-specified values .* not found in the dataset"),
139+
match=(
140+
r"OneHotEncodingTransformer: column b includes user-specified values \['g'\] not found in the dataset"
141+
),
138142
):
139143
transformer.fit(df)
140144

141-
142145
@pytest.mark.parametrize(
143146
"library",
144147
["pandas", "polars"],
@@ -366,22 +369,24 @@ def test_warning_generated_by_unseen_categories(self, library):
366369
transformer.transform(df_test)
367370

368371
@pytest.mark.parametrize(
369-
"library",
370-
["pandas", "polars"]
372+
"library",
373+
["pandas", "polars"],
371374
)
372375
def test_transform_missing_levels_warning(self, library):
373-
""" Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
376+
"""Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
374377
df_train = d.create_df_7(library=library)
375378
df_test = d.create_df_8(library=library)
376379

377-
transformer = OneHotEncodingTransformer(columns=["b"], wanted_values={"b": ["v", "x", "z"]})
380+
transformer = OneHotEncodingTransformer(
381+
columns=["b"], wanted_values={"b": ["v", "x", "z"]}
382+
)
378383

379384
transformer.fit(df_train)
380385

381386
with pytest.warns(
382-
UserWarning,
383-
match="OneHotEncodingTransformer: column b includes user-specified values .* not found in the dataset"
384-
):
387+
UserWarning,
388+
match=r"OneHotEncodingTransformer: column b includes user-specified values \['v'\] not found in the dataset",
389+
):
385390
transformer.transform(df_test)
386391

387392
@pytest.mark.parametrize(
@@ -427,3 +432,43 @@ def test_unseen_categories_encoded_as_all_zeroes(self, library):
427432
df_transformed_row[column_order],
428433
df_expected_row,
429434
)
435+
436+
437+
@pytest.mark.parametrize(
438+
"library",
439+
["pandas", "polars"],
440+
)
441+
def test_transform_missing_levels_encoded_as_all_zeroes(self, library):
442+
"""Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
443+
df_train = d.create_df_7(library=library)
444+
df_test = d.create_df_8(library=library)
445+
446+
transformer = OneHotEncodingTransformer(
447+
columns=["b"], wanted_values={"b": ["v", "x", "z"]}
448+
)
449+
450+
transformer.fit(df_train)
451+
df_transformed = transformer.transform(df_test)
452+
453+
expected_df_dict= {
454+
"a": [1, 5, 2, 3, 3],
455+
"b": ["w", "w", "z", "y", "x"],
456+
"c": ["a", "a", "c", "b", "a"],
457+
"b_v": [0]*5,
458+
"b_x": [0,0,0,0,1],
459+
"b_z":[0,0,1,0,0],
460+
}
461+
expected_df = dataframe_init_dispatch(library=library, dataframe_dict=expected_df_dict)
462+
expected_df = nw.from_native(expected_df)
463+
# cast the columns
464+
boolean_cols= ["b_v", "b_x", "b_z"]
465+
for col_name in boolean_cols:
466+
expected_df= expected_df.with_columns(
467+
nw.col(col_name).cast(nw.Boolean)
468+
)
469+
expected_df= expected_df.with_columns(
470+
nw.col("c").cast(nw.Categorical)
471+
)
472+
473+
assert_frame_equal_dispatch(df_transformed, expected_df.to_native())
474+

0 commit comments

Comments
 (0)