@@ -123,22 +123,25 @@ def test_nulls_in_X_error(self, library):
123
123
transformer .fit (df )
124
124
125
125
@pytest .mark .parametrize (
126
- "library" ,
127
- ["pandas" , "polars" ]
126
+ "library" ,
127
+ ["pandas" , "polars" ],
128
128
)
129
129
def test_fit_missing_levels_warning (self , library ):
130
- """ Test OneHotEncodingTransformer.fit triggers a warning for missing levels."""
130
+ """Test OneHotEncodingTransformer.fit triggers a warning for missing levels."""
131
131
df = d .create_df_1 (library = library )
132
132
133
- transformer = OneHotEncodingTransformer (columns = ["b" ], wanted_values = {"b" : ["f" , "g" ]})
133
+ transformer = OneHotEncodingTransformer (
134
+ columns = ["b" ], wanted_values = {"b" : ["f" , "g" ]}
135
+ )
134
136
135
137
with pytest .warns (
136
138
UserWarning ,
137
- match = ("OneHotEncodingTransformer: column b includes user-specified values .* not found in the dataset" ),
139
+ match = (
140
+ r"OneHotEncodingTransformer: column b includes user-specified values \['g'\] not found in the dataset"
141
+ ),
138
142
):
139
143
transformer .fit (df )
140
144
141
-
142
145
@pytest .mark .parametrize (
143
146
"library" ,
144
147
["pandas" , "polars" ],
@@ -366,22 +369,24 @@ def test_warning_generated_by_unseen_categories(self, library):
366
369
transformer .transform (df_test )
367
370
368
371
@pytest .mark .parametrize (
369
- "library" ,
370
- ["pandas" , "polars" ]
372
+ "library" ,
373
+ ["pandas" , "polars" ],
371
374
)
372
375
def test_transform_missing_levels_warning (self , library ):
373
- """ Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
376
+ """Test OneHotEncodingTransformer.transform triggers a warning for missing levels."""
374
377
df_train = d .create_df_7 (library = library )
375
378
df_test = d .create_df_8 (library = library )
376
379
377
- transformer = OneHotEncodingTransformer (columns = ["b" ], wanted_values = {"b" : ["v" , "x" , "z" ]})
380
+ transformer = OneHotEncodingTransformer (
381
+ columns = ["b" ], wanted_values = {"b" : ["v" , "x" , "z" ]}
382
+ )
378
383
379
384
transformer .fit (df_train )
380
385
381
386
with pytest .warns (
382
- UserWarning ,
383
- match = "OneHotEncodingTransformer: column b includes user-specified values .* not found in the dataset"
384
- ):
387
+ UserWarning ,
388
+ match = r"OneHotEncodingTransformer: column b includes user-specified values \['v'\] not found in the dataset",
389
+ ):
385
390
transformer .transform (df_test )
386
391
387
392
@pytest .mark .parametrize (
@@ -427,3 +432,43 @@ def test_unseen_categories_encoded_as_all_zeroes(self, library):
427
432
df_transformed_row [column_order ],
428
433
df_expected_row ,
429
434
)
435
+
436
+
437
+ @pytest .mark .parametrize (
438
+ "library" ,
439
+ ["pandas" , "polars" ],
440
+ )
441
+ def test_transform_missing_levels_encoded_as_all_zeroes (self , library ):
442
+ """Test OneHotEncodingTransformer.transform encodes missing levels as all zeroes."""
443
+ df_train = d .create_df_7 (library = library )
444
+ df_test = d .create_df_8 (library = library )
445
+
446
+ transformer = OneHotEncodingTransformer (
447
+ columns = ["b" ], wanted_values = {"b" : ["v" , "x" , "z" ]}
448
+ )
449
+
450
+ transformer .fit (df_train )
451
+ df_transformed = transformer .transform (df_test )
452
+
453
+ expected_df_dict = {
454
+ "a" : [1 , 5 , 2 , 3 , 3 ],
455
+ "b" : ["w" , "w" , "z" , "y" , "x" ],
456
+ "c" : ["a" , "a" , "c" , "b" , "a" ],
457
+ "b_v" : [0 ]* 5 ,
458
+ "b_x" : [0 ,0 ,0 ,0 ,1 ],
459
+ "b_z" :[0 ,0 ,1 ,0 ,0 ],
460
+ }
461
+ expected_df = dataframe_init_dispatch (library = library , dataframe_dict = expected_df_dict )
462
+ expected_df = nw .from_native (expected_df )
463
+ # cast the one-hot columns to Boolean (and "c" to Categorical below) before comparing with the transformed frame
464
+ boolean_cols = ["b_v" , "b_x" , "b_z" ]
465
+ for col_name in boolean_cols :
466
+ expected_df = expected_df .with_columns (
467
+ nw .col (col_name ).cast (nw .Boolean )
468
+ )
469
+ expected_df = expected_df .with_columns (
470
+ nw .col ("c" ).cast (nw .Categorical )
471
+ )
472
+
473
+ assert_frame_equal_dispatch (df_transformed , expected_df .to_native ())
474
+
0 commit comments