@@ -26,51 +26,48 @@ class TestInit(
26
26
def setup_class (cls ):
27
27
cls .transformer_name = "OneHotEncodingTransformer"
28
28
29
-
30
29
# Tests for wanted_values parameter
31
30
32
31
@pytest.mark.parametrize(
    "values",
    ["a", ["a", "b"], 123, True],
)
def test_wanted_values_is_dict(self, values, minimal_attribute_dict):
    """Check that a non-dict ``wanted_values`` raises a TypeError on init.

    Each parametrized ``values`` case (a string, a list, an int, a bool) is
    passed as ``wanted_values`` and the constructor is expected to raise.
    """
    # Build the kwargs as a new dict so the fixture's entry is not mutated
    # (the fixture's scope is not visible from this test).
    args = {
        **minimal_attribute_dict[self.transformer_name],
        "wanted_values": values,
    }

    with pytest.raises(
        TypeError,
        match="OneHotEncodingTransformer: Wanted_values should be a dictionary",
    ):
        OneHotEncodingTransformer(**args)
45
44
46
-
47
45
@pytest.mark.parametrize(
    "values",
    [
        {1: ["a", "b"]},
        {True: ["a"]},
        {("a",): ["b", "c"]},
    ],
)
def test_wanted_values_key_is_str(self, values, minimal_attribute_dict):
    """Check that non-string keys in ``wanted_values`` raise a TypeError on init.

    The parametrized cases use an int, a bool and a tuple as the dict key.
    """
    # Build the kwargs as a new dict so the fixture's entry is not mutated
    # (the fixture's scope is not visible from this test).
    args = {
        **minimal_attribute_dict[self.transformer_name],
        "wanted_values": values,
    }

    with pytest.raises(
        TypeError,
        match="OneHotEncodingTransformer: Key in 'wanted_values' should be a string",
    ):
        OneHotEncodingTransformer(**args)
64
62
65
-
66
63
@pytest .mark .parametrize (
67
64
"values" ,
68
65
[
69
66
{"a" : "b" },
70
- {"a" :("a" ,"b" )},
67
+ {"a" : ("a" , "b" )},
71
68
{"a" : True },
72
69
{"a" : 123 },
73
- ]
70
+ ],
74
71
)
75
72
def test_wanted_values_value_is_list (self , values , minimal_attribute_dict ):
76
73
args = minimal_attribute_dict [self .transformer_name ]
@@ -82,26 +79,24 @@ def test_wanted_values_value_is_list(self, values, minimal_attribute_dict):
82
79
):
83
80
OneHotEncodingTransformer (** args )
84
81
85
-
86
82
@pytest.mark.parametrize(
    "values",
    [
        {"a": ["b", 123]},
        {"a": ["b", True]},
        {"a": ["b", None]},
        {"a": ["b", ["a", "b"]]},
    ],
)
def test_wanted_values_entries_are_str(self, values, minimal_attribute_dict):
    """Check that non-string entries in a ``wanted_values`` list raise a TypeError.

    Every case pairs a valid string entry with one invalid entry (an int, a
    bool, ``None`` or a nested list).
    """
    # Build the kwargs as a new dict so the fixture's entry is not mutated
    # (the fixture's scope is not visible from this test).
    args = {
        **minimal_attribute_dict[self.transformer_name],
        "wanted_values": values,
    }

    with pytest.raises(
        TypeError,
        match="OneHotEncodingTransformer: Entries in 'wanted_values' list should be a string",
    ):
        OneHotEncodingTransformer(**args)
104
-
105
100
106
101
107
102
class TestFit (GenericFitTests ):
@@ -127,6 +122,23 @@ def test_nulls_in_X_error(self, library):
127
122
):
128
123
transformer .fit (df )
129
124
125
@pytest.mark.parametrize(
    "library",
    ["pandas", "polars"],
)
def test_fit_missing_levels_warning(self, library):
    """Test OneHotEncodingTransformer.fit triggers a warning for missing levels."""
    df = d.create_df_1(library=library)

    # NOTE(review): presumably "f" and "g" do not occur in column b of
    # create_df_1, which is what should trigger the warning - confirm
    # against the data helper.
    transformer = OneHotEncodingTransformer(
        columns=["b"],
        wanted_values={"b": ["f", "g"]},
    )

    with pytest.warns(
        UserWarning,
        match=r"OneHotEncodingTransformer: column b includes user-specified values .* not found in the dataset",
    ):
        transformer.fit(df)
140
+
141
+
130
142
@pytest .mark .parametrize (
131
143
"library" ,
132
144
["pandas" , "polars" ],
@@ -353,6 +365,25 @@ def test_warning_generated_by_unseen_categories(self, library):
353
365
with pytest .warns (UserWarning , match = "unseen categories" ):
354
366
transformer .transform (df_test )
355
367
368
@pytest.mark.parametrize(
    "library",
    ["pandas", "polars"],
)
def test_transform_missing_levels_warning(self, library):
    """Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
    df_train = d.create_df_7(library=library)
    df_test = d.create_df_8(library=library)

    # NOTE(review): presumably at least one of "v", "x", "z" is absent from
    # column b of create_df_8, which is what should trigger the warning at
    # transform time - confirm against the data helpers.
    transformer = OneHotEncodingTransformer(
        columns=["b"],
        wanted_values={"b": ["v", "x", "z"]},
    )

    transformer.fit(df_train)

    with pytest.warns(
        UserWarning,
        match=r"OneHotEncodingTransformer: column b includes user-specified values .* not found in the dataset",
    ):
        transformer.transform(df_test)
386
+
356
387
@pytest .mark .parametrize (
357
388
"library" ,
358
389
["pandas" , "polars" ],
0 commit comments