
Commit 10c0afa

change saving method to save fitted encoders
1 parent e284327 commit 10c0afa

File tree

5 files changed: +74 −42 lines changed

pipeline_lib/core/data_container.py

Lines changed: 26 additions & 1 deletion

@@ -10,6 +10,7 @@

 import pandas as pd
 import yaml
+from sklearn.compose import ColumnTransformer

 from pipeline_lib.core.model import Model

@@ -166,7 +167,7 @@ def save(self, file_path: str, keys: Optional[Union[str, list[str]]] = None):
         if isinstance(keys, str):
             keys = [keys]

-        data_to_save = {k: self.data[k] for k in keys} if keys else self.data
+        data_to_save = {k: self.data.get(k) for k in keys} if keys else self.data

         serialized_data = pickle.dumps(data_to_save)
         data_size_bytes = sys.getsizeof(serialized_data)

@@ -619,6 +620,30 @@ def is_train(self, value: bool):
         """
         self["is_train"] = value

+    @property
+    def _encoder(self) -> ColumnTransformer:
+        """
+        Get the encoder from the DataContainer.
+
+        Returns
+        -------
+        ColumnTransformer
+            The encoder stored in the DataContainer.
+        """
+        return self["encoder"]
+
+    @_encoder.setter
+    def _encoder(self, value: ColumnTransformer):
+        """
+        Set the encoder in the DataContainer.
+
+        Parameters
+        ----------
+        value
+            The encoder to be stored in the DataContainer.
+        """
+        self["encoder"] = value
+
     def __eq__(self, other) -> bool:
         """
         Compare this DataContainer with another for equality.
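Note: the new `_encoder` property is a thin alias over the container's key-value storage (the `"encoder"` key). A minimal sketch of the round trip it enables, assuming `DataContainer` can be constructed empty as in `Pipeline.run`; the transformer contents below are illustrative:

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder

from pipeline_lib.core import DataContainer

# Illustrative transformer over two hypothetical categorical columns.
transformer = ColumnTransformer(
    [("ordinal", OrdinalEncoder(), ["city", "color"])],
    remainder="passthrough",
)

data = DataContainer()
data._encoder = transformer            # stored under data["encoder"]
assert data._encoder is transformer    # read back through the property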

pipeline_lib/core/pipeline.py

Lines changed: 20 additions & 20 deletions

@@ -20,13 +20,13 @@ class Pipeline:
     step_registry = StepRegistry()
     model_registry = ModelRegistry()

+    KEYS_TO_SAVE = ["model", "encoder", "_drop_columns", "target"]
+
     def __init__(self, initial_data: Optional[DataContainer] = None):
         self.steps = []
         self.initial_data = initial_data
-        self.save_path = None
-        self.load_path = None
-        self.model_path = None
         self.config = None
+        self.save_data_path = None

     def add_steps(self, steps: list[PipelineStep]):
         """Add steps to the pipeline."""

@@ -35,22 +35,32 @@ def add_steps(self, steps: list[PipelineStep]):
     def run(self, is_train: bool, save: bool = True) -> DataContainer:
         """Run the pipeline on the given data."""

-        data = DataContainer.from_pickle(self.load_path) if self.load_path else DataContainer()
-        data.is_train = is_train
+        if not self.save_data_path:
+            raise ValueError(
+                "A path for saving the data must be provided. Use the `save_data_path` attribute."
+            )
+
+        data = DataContainer()

         if is_train:
             steps_to_run = [step for step in self.steps if step.used_for_training]
             self.logger.info("Training the pipeline")
         else:
+            data = DataContainer.from_pickle(self.save_data_path)
             steps_to_run = [step for step in self.steps if step.used_for_prediction]
             self.logger.info("Predicting with the pipeline")

+        data.is_train = is_train
+
         for i, step in enumerate(steps_to_run):
             Pipeline.logger.info(
                 f"Running {step.__class__.__name__} - {i + 1} / {len(steps_to_run)}"
             )
             data = step.execute(data)

+        if is_train:
+            data.save(self.save_data_path, keys=self.KEYS_TO_SAVE)
+
         if save:
             self.save_run(data)

@@ -78,17 +88,16 @@ def from_json(cls, path: str) -> Pipeline:
         if custom_steps_path:
             cls.step_registry.load_and_register_custom_steps(custom_steps_path)

+        save_data_path = config["pipeline"].get("save_data_path")
+
         pipeline = Pipeline()

-        pipeline.load_path = config.get("load_path")
-        pipeline.save_path = config.get("save_path")
         pipeline.config = config
+        pipeline.save_data_path = save_data_path

-        steps = []
+        print(f"Saved data path: {save_data_path}")

-        model_path = None
-        drop_columns = None
-        target = None
+        steps = []

         for step_config in config["pipeline"]["steps"]:
             step_type = step_config["step_type"]

@@ -103,15 +112,6 @@ def from_json(cls, path: str) -> Pipeline:
                 model_class_name = parameters.pop("model_class")
                 model_class = cls.model_registry.get_model_class(model_class_name)
                 parameters["model_class"] = model_class
-                model_path = parameters.get("save_path")
-                drop_columns = parameters.get("drop_columns")
-                target = parameters.get("target")
-
-            # if step type is prediction, add model path
-            if step_type == "PredictStep":
-                parameters["load_path"] = model_path
-                parameters["drop_columns"] = drop_columns
-                parameters["target"] = target

             step_class = cls.step_registry.get_step_class(step_type)
             step = step_class(**parameters)
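With `KEYS_TO_SAVE` and the new `save_data_path`, training and prediction now share a single pickle. A minimal usage sketch, assuming a config file (hypothetically named `train.json` here) whose `pipeline` section defines `save_data_path` and `steps`:

from pipeline_lib.core.pipeline import Pipeline

pipeline = Pipeline.from_json("train.json")  # hypothetical config path

# Training run: executes the training steps, then pickles
# ["model", "encoder", "_drop_columns", "target"] to save_data_path.
pipeline.run(is_train=True)

# Prediction run: reloads that pickle via DataContainer.from_pickle,
# so downstream steps reuse the fitted model and encoder.
pipeline.run(is_train=False)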

pipeline_lib/core/steps/encode.py

Lines changed: 25 additions & 6 deletions

@@ -36,6 +36,10 @@ def execute(self, data: DataContainer) -> DataContainer:
         """Execute the encoding step."""
         self.logger.info("Encoding data")
         df = data.flow
+
+        if not data.target and not self.target:
+            raise ValueError("Target column not found in any parameter before encoding.")
+
         target_column_name = self.target or data.target
         target_original_dtype = None

@@ -49,11 +53,21 @@
         if pd.api.types.is_numeric_dtype(df[target_column_name]):
             target_original_dtype = df[target_column_name].dtype

-        column_transformer = self._create_column_transformer(
-            high_cardinality_features, low_cardinality_features
-        )
+        if data.is_train:
+            column_transformer = self._create_column_transformer(
+                high_cardinality_features, low_cardinality_features
+            )
+            # Save the encoder for prediction
+            data._encoder = column_transformer
+        else:
+            column_transformer = data._encoder

-        encoded_data = self._transform_data(df, target_column_name, column_transformer)
+        encoded_data = self._transform_data(
+            df,
+            target_column_name,
+            column_transformer,
+            data.is_train,
+        )
         encoded_data = self._restore_column_order(df, encoded_data)
         encoded_data = self._convert_ordinal_encoded_columns_to_int(encoded_data)
         encoded_data = self._restore_numeric_dtypes(encoded_data, original_numeric_dtypes)

@@ -157,10 +171,15 @@ def _create_column_transformer(
         )

     def _transform_data(
-        self, df: pd.DataFrame, target_column_name: str, column_transformer: ColumnTransformer
+        self,
+        df: pd.DataFrame,
+        target_column_name: str,
+        column_transformer: ColumnTransformer,
+        is_train: bool,
     ) -> pd.DataFrame:
         """Transform the data using the ColumnTransformer."""
-        column_transformer.fit(df, df[target_column_name])
+        if is_train:
+            column_transformer.fit(df, df[target_column_name])
         transformed_data = column_transformer.transform(df)
         self.logger.debug(f"Transformed data shape: {transformed_data.shape}")
         return pd.DataFrame(transformed_data, columns=column_transformer.get_feature_names_out())
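The `is_train` guard follows the standard scikit-learn pattern: fit the encoder once on training data, then reuse the already-fitted encoder at prediction time so categories map to the same codes. A self-contained sketch of that pattern (column names and values are illustrative):

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder

train_df = pd.DataFrame({"city": ["NY", "LA", "NY"], "target": [1, 0, 1]})
predict_df = pd.DataFrame({"city": ["LA", "NY"], "target": [0, 1]})

encoder = ColumnTransformer(
    [("ordinal", OrdinalEncoder(), ["city"])], remainder="drop"
)

# is_train=True: fit on the training frame, then transform it.
encoder.fit(train_df, train_df["target"])
train_encoded = encoder.transform(train_df)

# is_train=False: transform only, reusing the fitted encoder so
# "LA" and "NY" receive the same ordinal codes as during training.
predict_encoded = encoder.transform(predict_df)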

pipeline_lib/core/steps/explainer_dashboard.py

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ def execute(self, data: DataContainer) -> DataContainer:
             self.logger.info(f"Sampling {self.max_samples} data points from the dataset.")
             df = df.sample(n=self.max_samples, random_state=42)

-        drop_columns = data._drop_columns
+        drop_columns = data._drop_columns + ["predictions"]
         if drop_columns:
             df = df.drop(columns=drop_columns)
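Since `PredictStep` now appends a `predictions` column to `data.flow`, the dashboard step must drop it alongside the configured drop columns. A tiny illustration of the resulting drop (all names except `predictions` are made up):

import pandas as pd

df = pd.DataFrame({"feature": [1.0, 2.0], "row_id": [10, 20], "predictions": [0.2, 0.9]})

drop_columns = ["row_id"] + ["predictions"]  # stands in for data._drop_columns + ["predictions"]
df = df.drop(columns=drop_columns)
print(list(df.columns))  # ['feature']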

pipeline_lib/core/steps/predict.py

Lines changed: 2 additions & 14 deletions

@@ -1,7 +1,4 @@
-from typing import List, Optional
-
 from pipeline_lib.core import DataContainer
-from pipeline_lib.core.model import Model
 from pipeline_lib.core.steps.base import PipelineStep

@@ -13,23 +10,16 @@ class PredictStep(PipelineStep):

     def __init__(
         self,
-        load_path: str,
-        target: str,
-        drop_columns: Optional[List[str]] = None,
     ) -> None:
         """Initialize Predict Step."""
         super().__init__()
         self.init_logger()
-        self.load_path = load_path
-        self.model = Model.from_file(load_path)
-        self.target = target
-        self.drop_columns = drop_columns or []

     def execute(self, data: DataContainer) -> DataContainer:
         """Execute the step."""
         self.logger.info("Obtaining predictions")

-        drop_columns = self.drop_columns + [self.target]
+        drop_columns = data._drop_columns + [data.target]

         missing_columns = [col for col in drop_columns if col not in data.flow.columns]
         if missing_columns:

@@ -39,10 +29,8 @@ def execute(self, data: DataContainer) -> DataContainer:
             self.logger.warning(error_message)
             raise KeyError(error_message)

-        data.predictions = self.model.predict(data.flow.drop(columns=drop_columns))
+        data.predictions = data.model.predict(data.flow.drop(columns=drop_columns))

         data.flow["predictions"] = data.predictions

-        data.target = self.target
-
         return data
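`PredictStep` is now parameterless and relies entirely on the reloaded `DataContainer` for the model, target, and drop columns. A minimal sketch of that contract, assuming the `model`, `target`, `_drop_columns`, and `flow` attributes map to container keys of the same names (the dummy model and frame are illustrative):

import pandas as pd

from pipeline_lib.core import DataContainer
from pipeline_lib.core.steps.predict import PredictStep

class DummyModel:
    """Stand-in for a fitted model; predicts a constant."""
    def predict(self, df: pd.DataFrame):
        return [1] * len(df)

# In a real run these keys are restored by DataContainer.from_pickle(save_data_path).
data = DataContainer()
data["model"] = DummyModel()
data["target"] = "target"
data["_drop_columns"] = ["row_id"]
data["flow"] = pd.DataFrame({"feature": [0.5], "row_id": [1], "target": [1]})

step = PredictStep()        # no load_path/target/drop_columns parameters
data = step.execute(data)   # drops _drop_columns + target before predicting
print(data.flow["predictions"])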
