Skip to content

Commit e284327

Browse files
committed
check supported encoders in __init__ instead of exectue
1 parent b755bb2 commit e284327

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

pipeline_lib/core/steps/encode.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def __init__(
2828
self.init_logger()
2929
self.target = target
3030
self.cardinality_threshold = cardinality_threshold
31-
self.low_cardinality_encoder = low_cardinality_encoder
32-
self.high_cardinality_encoder = high_cardinality_encoder
31+
self.high_cardinality_encoder = self._get_encoder(high_cardinality_encoder)
32+
self.low_cardinality_encoder = self._get_encoder(low_cardinality_encoder)
3333
self.encoder_feature_map = {}
3434

3535
def execute(self, data: DataContainer) -> DataContainer:
@@ -123,29 +123,34 @@ def _create_column_transformer(
123123
self, high_cardinality_features: List[str], low_cardinality_features: List[str]
124124
) -> ColumnTransformer:
125125
"""Create a ColumnTransformer for encoding."""
126-
high_cardinality_encoder = self._get_encoder(self.high_cardinality_encoder)
127-
low_cardinality_encoder = self._get_encoder(self.low_cardinality_encoder)
128126

129127
# Initialize the encoder_feature_map as an empty dictionary
130128
self.encoder_feature_map = {}
131129

130+
high_cardinality_encoder_name = self.high_cardinality_encoder.__class__.__name__
131+
low_cardinality_encoder_name = self.low_cardinality_encoder.__class__.__name__
132+
132133
# Check if both encoders are the same
133-
if self.high_cardinality_encoder == self.low_cardinality_encoder:
134+
if high_cardinality_encoder_name == low_cardinality_encoder_name:
134135
# If the same, merge the feature lists
135136
# This assumes you want to combine the features into a single list; adjust if needed
136137
combined_features = high_cardinality_features + low_cardinality_features
137-
self.encoder_feature_map[self.high_cardinality_encoder] = combined_features
138+
self.encoder_feature_map[high_cardinality_encoder_name] = combined_features
138139
else:
139140
# If not the same, assign individually
140-
self.encoder_feature_map[self.high_cardinality_encoder] = high_cardinality_features
141-
self.encoder_feature_map[self.low_cardinality_encoder] = low_cardinality_features
141+
self.encoder_feature_map[high_cardinality_encoder_name] = high_cardinality_features
142+
self.encoder_feature_map[low_cardinality_encoder_name] = low_cardinality_features
142143

143144
self.logger.info(f"Encoder feature map: \n{json.dumps(self.encoder_feature_map, indent=4)}")
144145

145146
return ColumnTransformer(
146147
[
147-
("high_cardinality_encoder", high_cardinality_encoder, high_cardinality_features),
148-
("low_cardinality_encoder", low_cardinality_encoder, low_cardinality_features),
148+
(
149+
"high_cardinality_encoder",
150+
self.high_cardinality_encoder,
151+
high_cardinality_features,
152+
),
153+
("low_cardinality_encoder", self.low_cardinality_encoder, low_cardinality_features),
149154
],
150155
remainder="passthrough",
151156
verbose_feature_names_out=False,

0 commit comments

Comments
 (0)