Skip to content

Commit 3e63edf

Browse files
committed
add load and save in json
1 parent 93d304d commit 3e63edf

File tree

5 files changed

+24
-8
lines changed

5 files changed

+24
-8
lines changed

pipeline_lib/core/data_container.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class DataContainer:
3333
PREDICTIONS = "predictions"
3434
EXPLAINER = "explainer"
3535
TUNING_PARAMS = "tuning_params"
36+
TARGET = "target"
3637

3738
def __init__(self, initial_data: Optional[dict] = None):
3839
"""
@@ -184,6 +185,7 @@ def save(self, file_path: str, keys: Optional[Union[str, list[str]]] = None):
184185

185186
with open(file_path, "wb") as file:
186187
file.write(serialized_data)
188+
187189
self.logger.info(
188190
f"{self.__class__.__name__} serialized and saved to {file_path}. Size:"
189191
f" {data_size_mb:.2f} MB"

pipeline_lib/core/pipeline.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(self, initial_data: Optional[DataContainer] = None):
2222
if not all(isinstance(step, PipelineStep) for step in self.steps):
2323
raise TypeError("All steps must be instances of PipelineStep")
2424
self.initial_data = initial_data
25+
self.save_path = None
26+
self.load_path = None
2527

2628
@classmethod
2729
def register_step(cls, step_class):
@@ -102,11 +104,15 @@ def load_and_register_custom_steps(cls, custom_steps_path: str) -> None:
102104
def run(self) -> DataContainer:
103105
"""Run the pipeline on the given data."""
104106

105-
data = DataContainer()
107+
data = DataContainer.from_pickle(self.load_path) if self.load_path else DataContainer()
106108

107109
for i, step in enumerate(self.steps):
108110
Pipeline.logger.info(f"Running {step.__class__.__name__} - {i + 1} / {len(self.steps)}")
109111
data = step.execute(data)
112+
113+
if self.save_path:
114+
data.save(self.save_path)
115+
110116
return data
111117

112118
@classmethod
@@ -124,6 +130,10 @@ def from_json(cls, path: str) -> Pipeline:
124130
Pipeline.load_and_register_custom_steps(custom_steps_path)
125131

126132
pipeline = Pipeline()
133+
134+
pipeline.load_path = config.get("load_path")
135+
pipeline.save_path = config.get("save_path")
136+
127137
steps = []
128138

129139
for step_config in config["pipeline"]["steps"]:

pipeline_lib/core/steps/calculate_metrics.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ def __init__(self, config: Optional[dict] = None) -> None:
1919
def execute(self, data: DataContainer) -> DataContainer:
2020
self.logger.debug("Starting metric calculation")
2121
model_output = data[DataContainer.MODEL_OUTPUT]
22-
model_configs = data[DataContainer.MODEL_CONFIGS]
23-
target_column_name = model_configs.get("target")
22+
23+
target_column_name = data.get(DataContainer.TARGET)
2424

2525
if target_column_name is None:
26-
raise ValueError("Target column not found in model_configs.")
26+
raise ValueError("Target column not found on any configuration.")
2727

2828
true_values = model_output[target_column_name]
2929
predictions = model_output[DataContainer.PREDICTIONS]
@@ -32,5 +32,6 @@ def execute(self, data: DataContainer) -> DataContainer:
3232
rmse = np.sqrt(mean_squared_error(true_values, predictions))
3333

3434
results = {"MAE": str(mae), "RMSE": str(rmse)}
35+
self.logger.info(results)
3536
data[DataContainer.METRICS] = results
3637
return data

pipeline_lib/implementation/tabular/xgboost/fit_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def execute(self, data: DataContainer) -> DataContainer:
2727
if target is None:
2828
raise ValueError("Target column not found in model_configs.")
2929

30+
data[DataContainer.TARGET] = target
31+
3032
df_train = data[DataContainer.TRAIN]
3133
df_valid = data[DataContainer.VALIDATION]
3234

pipeline_lib/implementation/tabular/xgboost/predict.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@ def execute(self, data: DataContainer) -> DataContainer:
1919
if not isinstance(model_input, pd.DataFrame):
2020
raise ValueError("model_input must be a pandas DataFrame.")
2121

22-
model_configs = data.get(DataContainer.MODEL_CONFIGS)
23-
if model_configs:
24-
drop_columns: list[str] = model_configs.get("drop_columns")
22+
if self.config:
23+
drop_columns = self.config.get("drop_columns")
2524
if drop_columns:
2625
model_input = model_input.drop(columns=drop_columns)
27-
target = model_configs.get("target")
26+
27+
target = self.config.get("target")
2828
if target is None:
2929
raise ValueError("Target column not found in model_configs.")
30+
data[DataContainer.TARGET] = target
3031

3132
predictions = model.predict(model_input.drop(columns=[target]))
3233
else:

0 commit comments

Comments
 (0)