Skip to content

Commit 2d9d259

Browse files
committed
personalize encoders in json
1 parent 3703bdf commit 2d9d259

File tree

1 file changed

+62
-24
lines changed

1 file changed

+62
-24
lines changed

pipeline_lib/core/steps/encode.py

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import List, Optional, Tuple
1+
import json
2+
from typing import List, Optional, Tuple, Union
23

34
import numpy as np
45
import pandas as pd
@@ -16,11 +17,20 @@ class EncodeStep(PipelineStep):
1617
used_for_prediction = True
1718
used_for_training = True
1819

19-
def __init__(self, target: Optional[str] = None, cardinality_threshold: float = 0.3) -> None:
20+
def __init__(
21+
self,
22+
target: Optional[str] = None,
23+
cardinality_threshold: float = 0.3,
24+
low_cardinality_encoder: str = "OrdinalEncoder",
25+
high_cardinality_encoder: str = "TargetEncoder",
26+
) -> None:
2027
"""Initialize EncodeStep."""
2128
self.init_logger()
2229
self.target = target
2330
self.cardinality_threshold = cardinality_threshold
31+
self.low_cardinality_encoder = low_cardinality_encoder
32+
self.high_cardinality_encoder = high_cardinality_encoder
33+
self.encoder_feature_map = {}
2434

2535
def execute(self, data: DataContainer) -> DataContainer:
2636
"""Execute the encoding step."""
@@ -39,28 +49,26 @@ def execute(self, data: DataContainer) -> DataContainer:
3949
if pd.api.types.is_numeric_dtype(df[target_column_name]):
4050
target_original_dtype = df[target_column_name].dtype
4151

42-
self._log_feature_info(
43-
categorical_features,
44-
numeric_features,
45-
low_cardinality_features,
46-
high_cardinality_features,
47-
)
48-
4952
column_transformer = self._create_column_transformer(
5053
high_cardinality_features, low_cardinality_features
5154
)
5255

5356
encoded_data = self._transform_data(df, target_column_name, column_transformer)
5457
encoded_data = self._restore_column_order(df, encoded_data)
55-
encoded_data = self._convert_ordinal_encoded_columns_to_int(
56-
encoded_data, column_transformer
57-
)
58+
encoded_data = self._convert_ordinal_encoded_columns_to_int(encoded_data)
5859
encoded_data = self._restore_numeric_dtypes(encoded_data, original_numeric_dtypes)
5960
encoded_data = self._restore_target_dtype(
6061
encoded_data, target_column_name, target_original_dtype
6162
)
6263
encoded_data = self._convert_float64_to_float32(encoded_data)
6364

65+
self._log_feature_info(
66+
categorical_features,
67+
numeric_features,
68+
low_cardinality_features,
69+
high_cardinality_features,
70+
)
71+
6472
data.flow = encoded_data
6573

6674
return data
@@ -93,14 +101,50 @@ def _split_categorical_features(
93101
]
94102
return low_cardinality_features, high_cardinality_features
95103

104+
def _get_encoder(self, encoder_name: str) -> Union[OrdinalEncoder, TargetEncoder]:
105+
"""Map encoder name to the corresponding encoder class."""
106+
encoder_map = {
107+
"OrdinalEncoder": OrdinalEncoder(),
108+
"TargetEncoder": TargetEncoder(),
109+
# Add more encoders as needed
110+
}
111+
112+
encoder = encoder_map.get(encoder_name)
113+
114+
if not encoder:
115+
raise ValueError(
116+
f"Unsupported encoder: {encoder_name}. Supported encoders: {encoder_map}"
117+
)
118+
119+
return encoder
120+
96121
def _create_column_transformer(
97122
self, high_cardinality_features: List[str], low_cardinality_features: List[str]
98123
) -> ColumnTransformer:
99124
"""Create a ColumnTransformer for encoding."""
125+
high_cardinality_encoder = self._get_encoder(self.high_cardinality_encoder)
126+
low_cardinality_encoder = self._get_encoder(self.low_cardinality_encoder)
127+
128+
# Initialize the encoder_feature_map as an empty dictionary
129+
self.encoder_feature_map = {}
130+
131+
# Check if both encoders are the same
132+
if self.high_cardinality_encoder == self.low_cardinality_encoder:
133+
# If the same, merge the feature lists
134+
# This assumes you want to combine the features into a single list; adjust if needed
135+
combined_features = high_cardinality_features + low_cardinality_features
136+
self.encoder_feature_map[self.high_cardinality_encoder] = combined_features
137+
else:
138+
# If not the same, assign individually
139+
self.encoder_feature_map[self.high_cardinality_encoder] = high_cardinality_features
140+
self.encoder_feature_map[self.low_cardinality_encoder] = low_cardinality_features
141+
142+
print(self.encoder_feature_map)
143+
100144
return ColumnTransformer(
101145
[
102-
("target_encoder", TargetEncoder(), high_cardinality_features),
103-
("ordinal_encoder", OrdinalEncoder(), low_cardinality_features),
146+
("high_cardinality_encoder", high_cardinality_encoder, high_cardinality_features),
147+
("low_cardinality_encoder", low_cardinality_encoder, low_cardinality_features),
104148
],
105149
remainder="passthrough",
106150
verbose_feature_names_out=False,
@@ -120,15 +164,11 @@ def _restore_column_order(self, df: pd.DataFrame, encoded_data: pd.DataFrame) ->
120164
new_column_order = [col for col in df.columns if col in encoded_data.columns]
121165
return encoded_data[new_column_order]
122166

123-
def _convert_ordinal_encoded_columns_to_int(
124-
self, encoded_data: pd.DataFrame, column_transformer: ColumnTransformer
125-
) -> pd.DataFrame:
167+
def _convert_ordinal_encoded_columns_to_int(self, encoded_data: pd.DataFrame) -> pd.DataFrame:
126168
"""Convert ordinal encoded columns to the smallest possible integer dtype."""
127-
ordinal_encoder_features = column_transformer.named_transformers_[
128-
"ordinal_encoder"
129-
].get_feature_names_out()
169+
ordinal_encoded_features = self.encoder_feature_map.get("OrdinalEncoder", [])
130170

131-
for col in ordinal_encoder_features:
171+
for col in ordinal_encoded_features:
132172
if col in encoded_data.columns:
133173
n_unique = encoded_data[col].nunique()
134174
if n_unique <= 2**8:
@@ -199,11 +239,9 @@ def _log_feature_info(
199239
f"Low cardinality features (cardinality ratio < {self.cardinality_threshold}):"
200240
f" ({len(low_cardinality_features)}) - {low_cardinality_features}"
201241
)
202-
self.logger.info("Low cardinality features encoding method: ordinal encoder")
203242
self.logger.info(
204243
f"High cardinality features (cardinality ratio >= {self.cardinality_threshold}):"
205244
f" ({len(high_cardinality_features)}) - {high_cardinality_features}"
206245
)
207-
self.logger.info("High cardinality features encoding method: target encoder")
208246
self.logger.info(f"Numeric features: ({len(numeric_features)}) - {numeric_features}")
209-
self.logger.info("Numeric features encoding method: passthrough")
247+
self.logger.info(f"Encoder feature map: \n{json.dumps(self.encoder_feature_map, indent=4)}")

0 commit comments

Comments
 (0)