@@ -1088,7 +1088,7 @@ class OneHotEncodingTransformer(
1088
1088
columns in X are used.
1089
1089
1090
1090
values: dict[str, list[str]] or None, default = None
1091
- Optional parameter to select specific columns to be transformed. If it is None, all categorical columns will be encoded.
1091
+ Optional mapping from column name to the list of levels in that column to be encoded. Columns absent from the mapping have all of their levels encoded. If it is None, all levels of every categorical column will be encoded.
1092
1092
1093
1093
separator : str
1094
1094
Used to create dummy column names, the name will take
@@ -1126,7 +1126,7 @@ class attribute, indicates whether transformer has been converted to polars/pand
1126
1126
def __init__ (
1127
1127
self ,
1128
1128
columns : str | list [str ] | None = None ,
1129
- values : list [str ] | None = None ,
1129
+ values : dict [ str , list [str ] ] | None = None ,
1130
1130
separator : str = "_" ,
1131
1131
drop_original : bool = False ,
1132
1132
copy : bool | None = None ,
@@ -1189,25 +1189,23 @@ def fit(self, X: FrameT, y: nw.Series | None = None) -> FrameT:
1189
1189
# for consistency
1190
1190
levels_list .sort ()
1191
1191
1192
- self .categories_ [c ] = levels_list
1192
+ # filter if 'values' is provided
1193
+ if self .values is not None :
1194
+ selected_values = self .values .get (c , None )
1193
1195
1194
- self .new_feature_names_ [c ] = self ._get_feature_names (column = c )
1196
+ if selected_values is not None :
1197
+ levels_list = [
1198
+ level for level in levels_list if level in selected_values
1199
+ ]
1195
1200
1196
- # filter categories if values is provided
1197
- if self .values is not None :
1198
- self .categories_ = {
1199
- c : self .categories_ [c ] for c in self .values if c in self .categories_
1200
- }
1201
- self .new_feature_names_ = {
1202
- c : self .new_feature_names_ [c ]
1203
- for c in self .values
1204
- if c in self .new_feature_names_
1205
- }
1201
+ if levels_list :
1202
+ self .categories_ [c ] = levels_list
1203
+
1204
+ self .new_feature_names_ [c ] = self ._get_feature_names (column = c )
1206
1205
1207
- # checks if column in 'values' exist in categories
1208
- if not self .categories_ :
1209
- error_message = "No valid columns in 'values' for encoding"
1210
- raise ValueError (error_message )
1206
+ if not self .categories_ :
1207
+ error_message = "No valid categories found in 'values' for encoding"
1208
+ raise ValueError (error_message )
1211
1209
1212
1210
return self
1213
1211
0 commit comments