1
- from typing import List , Optional , Tuple
1
+ import json
2
+ from typing import List , Optional , Tuple , Union
2
3
3
4
import numpy as np
4
5
import pandas as pd
@@ -16,11 +17,20 @@ class EncodeStep(PipelineStep):
16
17
used_for_prediction = True
17
18
used_for_training = True
18
19
19
- def __init__ (self , target : Optional [str ] = None , cardinality_threshold : float = 0.3 ) -> None :
20
+ def __init__ (
21
+ self ,
22
+ target : Optional [str ] = None ,
23
+ cardinality_threshold : float = 0.3 ,
24
+ low_cardinality_encoder : str = "OrdinalEncoder" ,
25
+ high_cardinality_encoder : str = "TargetEncoder" ,
26
+ ) -> None :
20
27
"""Initialize EncodeStep."""
21
28
self .init_logger ()
22
29
self .target = target
23
30
self .cardinality_threshold = cardinality_threshold
31
+ self .low_cardinality_encoder = low_cardinality_encoder
32
+ self .high_cardinality_encoder = high_cardinality_encoder
33
+ self .encoder_feature_map = {}
24
34
25
35
def execute (self , data : DataContainer ) -> DataContainer :
26
36
"""Execute the encoding step."""
@@ -39,28 +49,26 @@ def execute(self, data: DataContainer) -> DataContainer:
39
49
if pd .api .types .is_numeric_dtype (df [target_column_name ]):
40
50
target_original_dtype = df [target_column_name ].dtype
41
51
42
- self ._log_feature_info (
43
- categorical_features ,
44
- numeric_features ,
45
- low_cardinality_features ,
46
- high_cardinality_features ,
47
- )
48
-
49
52
column_transformer = self ._create_column_transformer (
50
53
high_cardinality_features , low_cardinality_features
51
54
)
52
55
53
56
encoded_data = self ._transform_data (df , target_column_name , column_transformer )
54
57
encoded_data = self ._restore_column_order (df , encoded_data )
55
- encoded_data = self ._convert_ordinal_encoded_columns_to_int (
56
- encoded_data , column_transformer
57
- )
58
+ encoded_data = self ._convert_ordinal_encoded_columns_to_int (encoded_data )
58
59
encoded_data = self ._restore_numeric_dtypes (encoded_data , original_numeric_dtypes )
59
60
encoded_data = self ._restore_target_dtype (
60
61
encoded_data , target_column_name , target_original_dtype
61
62
)
62
63
encoded_data = self ._convert_float64_to_float32 (encoded_data )
63
64
65
+ self ._log_feature_info (
66
+ categorical_features ,
67
+ numeric_features ,
68
+ low_cardinality_features ,
69
+ high_cardinality_features ,
70
+ )
71
+
64
72
data .flow = encoded_data
65
73
66
74
return data
@@ -93,14 +101,50 @@ def _split_categorical_features(
93
101
]
94
102
return low_cardinality_features , high_cardinality_features
95
103
104
+ def _get_encoder (self , encoder_name : str ) -> Union [OrdinalEncoder , TargetEncoder ]:
105
+ """Map encoder name to the corresponding encoder class."""
106
+ encoder_map = {
107
+ "OrdinalEncoder" : OrdinalEncoder (),
108
+ "TargetEncoder" : TargetEncoder (),
109
+ # Add more encoders as needed
110
+ }
111
+
112
+ encoder = encoder_map .get (encoder_name )
113
+
114
+ if not encoder :
115
+ raise ValueError (
116
+ f"Unsupported encoder: { encoder_name } . Supported encoders: { encoder_map } "
117
+ )
118
+
119
+ return encoder
120
+
96
121
def _create_column_transformer (
97
122
self , high_cardinality_features : List [str ], low_cardinality_features : List [str ]
98
123
) -> ColumnTransformer :
99
124
"""Create a ColumnTransformer for encoding."""
125
+ high_cardinality_encoder = self ._get_encoder (self .high_cardinality_encoder )
126
+ low_cardinality_encoder = self ._get_encoder (self .low_cardinality_encoder )
127
+
128
+ # Initialize the encoder_feature_map as an empty dictionary
129
+ self .encoder_feature_map = {}
130
+
131
+ # Check if both encoders are the same
132
+ if self .high_cardinality_encoder == self .low_cardinality_encoder :
133
+ # If the same, merge the feature lists
134
+ # This assumes you want to combine the features into a single list; adjust if needed
135
+ combined_features = high_cardinality_features + low_cardinality_features
136
+ self .encoder_feature_map [self .high_cardinality_encoder ] = combined_features
137
+ else :
138
+ # If not the same, assign individually
139
+ self .encoder_feature_map [self .high_cardinality_encoder ] = high_cardinality_features
140
+ self .encoder_feature_map [self .low_cardinality_encoder ] = low_cardinality_features
141
+
142
+ print (self .encoder_feature_map )
143
+
100
144
return ColumnTransformer (
101
145
[
102
- ("target_encoder " , TargetEncoder () , high_cardinality_features ),
103
- ("ordinal_encoder " , OrdinalEncoder () , low_cardinality_features ),
146
+ ("high_cardinality_encoder " , high_cardinality_encoder , high_cardinality_features ),
147
+ ("low_cardinality_encoder " , low_cardinality_encoder , low_cardinality_features ),
104
148
],
105
149
remainder = "passthrough" ,
106
150
verbose_feature_names_out = False ,
@@ -120,15 +164,11 @@ def _restore_column_order(self, df: pd.DataFrame, encoded_data: pd.DataFrame) ->
120
164
new_column_order = [col for col in df .columns if col in encoded_data .columns ]
121
165
return encoded_data [new_column_order ]
122
166
123
- def _convert_ordinal_encoded_columns_to_int (
124
- self , encoded_data : pd .DataFrame , column_transformer : ColumnTransformer
125
- ) -> pd .DataFrame :
167
+ def _convert_ordinal_encoded_columns_to_int (self , encoded_data : pd .DataFrame ) -> pd .DataFrame :
126
168
"""Convert ordinal encoded columns to the smallest possible integer dtype."""
127
- ordinal_encoder_features = column_transformer .named_transformers_ [
128
- "ordinal_encoder"
129
- ].get_feature_names_out ()
169
+ ordinal_encoded_features = self .encoder_feature_map .get ("OrdinalEncoder" , [])
130
170
131
- for col in ordinal_encoder_features :
171
+ for col in ordinal_encoded_features :
132
172
if col in encoded_data .columns :
133
173
n_unique = encoded_data [col ].nunique ()
134
174
if n_unique <= 2 ** 8 :
@@ -199,11 +239,9 @@ def _log_feature_info(
199
239
f"Low cardinality features (cardinality ratio < { self .cardinality_threshold } ):"
200
240
f" ({ len (low_cardinality_features )} ) - { low_cardinality_features } "
201
241
)
202
- self .logger .info ("Low cardinality features encoding method: ordinal encoder" )
203
242
self .logger .info (
204
243
f"High cardinality features (cardinality ratio >= { self .cardinality_threshold } ):"
205
244
f" ({ len (high_cardinality_features )} ) - { high_cardinality_features } "
206
245
)
207
- self .logger .info ("High cardinality features encoding method: target encoder" )
208
246
self .logger .info (f"Numeric features: ({ len (numeric_features )} ) - { numeric_features } " )
209
- self .logger .info ("Numeric features encoding method: passthrough " )
247
+ self .logger .info (f"Encoder feature map: \n { json . dumps ( self . encoder_feature_map , indent = 4 ) } " )
0 commit comments