Skip to content

Commit 7237cc7

Browse files
committedFeb 20, 2025·
'Corrected 'values' feature'
1 parent 68c99f5 commit 7237cc7

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed
 

‎tubular/nominal.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ class OneHotEncodingTransformer(
10881088
columns in X are used.
10891089
10901090
values: list of strings or None, default = None
1091-
Optional parameter to select specific columns to be transformed. If it is None, all categorical columns will be encoded.
1091+
Optional parameter to select specific column levels to be transformed. If it is None, all levels in the categorical column will be encoded.
10921092
10931093
separator : str
10941094
Used to create dummy column names, the name will take
@@ -1126,7 +1126,7 @@ class attribute, indicates whether transformer has been converted to polars/pand
11261126
def __init__(
11271127
self,
11281128
columns: str | list[str] | None = None,
1129-
values: list[str] | None = None,
1129+
values: dict[str, list[str]] | None = None,
11301130
separator: str = "_",
11311131
drop_original: bool = False,
11321132
copy: bool | None = None,
@@ -1189,25 +1189,23 @@ def fit(self, X: FrameT, y: nw.Series | None = None) -> FrameT:
11891189
# for consistency
11901190
levels_list.sort()
11911191

1192-
self.categories_[c] = levels_list
1192+
# filter if 'values' is provided
1193+
if self.values is not None:
1194+
selected_values = self.values.get(c, None)
11931195

1194-
self.new_feature_names_[c] = self._get_feature_names(column=c)
1196+
if selected_values is not None:
1197+
levels_list = [
1198+
level for level in levels_list if level in selected_values
1199+
]
11951200

1196-
# filter categories if values is provided
1197-
if self.values is not None:
1198-
self.categories_ = {
1199-
c: self.categories_[c] for c in self.values if c in self.categories_
1200-
}
1201-
self.new_feature_names_ = {
1202-
c: self.new_feature_names_[c]
1203-
for c in self.values
1204-
if c in self.new_feature_names_
1205-
}
1201+
if levels_list:
1202+
self.categories_[c] = levels_list
1203+
1204+
self.new_feature_names_[c] = self._get_feature_names(column=c)
12061205

1207-
# checks if column in 'values' exist in categories
1208-
if not self.categories_:
1209-
error_message = "No valid columns in 'values' for encoding"
1210-
raise ValueError(error_message)
1206+
if not self.categories_:
1207+
error_message = "No valid categories found in 'values' for encoding"
1208+
raise ValueError(error_message)
12111209

12121210
return self
12131211

0 commit comments

Comments
 (0)
Please sign in to comment.