Skip to content

Commit

Permalink
Add load and save (pickle serialization paths configured via JSON)
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Mar 5, 2024
1 parent 93d304d commit 3e63edf
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 8 deletions.
2 changes: 2 additions & 0 deletions pipeline_lib/core/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class DataContainer:
PREDICTIONS = "predictions"
EXPLAINER = "explainer"
TUNING_PARAMS = "tuning_params"
TARGET = "target"

def __init__(self, initial_data: Optional[dict] = None):
"""
Expand Down Expand Up @@ -184,6 +185,7 @@ def save(self, file_path: str, keys: Optional[Union[str, list[str]]] = None):

with open(file_path, "wb") as file:
file.write(serialized_data)

self.logger.info(
f"{self.__class__.__name__} serialized and saved to {file_path}. Size:"
f" {data_size_mb:.2f} MB"
Expand Down
12 changes: 11 additions & 1 deletion pipeline_lib/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(self, initial_data: Optional[DataContainer] = None):
if not all(isinstance(step, PipelineStep) for step in self.steps):
raise TypeError("All steps must be instances of PipelineStep")
self.initial_data = initial_data
self.save_path = None
self.load_path = None

@classmethod
def register_step(cls, step_class):
Expand Down Expand Up @@ -102,11 +104,15 @@ def load_and_register_custom_steps(cls, custom_steps_path: str) -> None:
def run(self) -> DataContainer:
    """Run every registered step in order and return the resulting data.

    If ``self.load_path`` is set, the starting :class:`DataContainer` is
    deserialized from that pickle file; otherwise an empty container is
    created. Each step's ``execute`` receives the container produced by the
    previous step. If ``self.save_path`` is set, the final container is
    serialized there before being returned.

    Returns:
        DataContainer: the container produced by the last step.
    """
    # NOTE: the diff rendering duplicated the pre-change line
    # `data = DataContainer()` here; it was dead code (immediately
    # overwritten) and has been removed.
    data = DataContainer.from_pickle(self.load_path) if self.load_path else DataContainer()

    total = len(self.steps)
    for i, step in enumerate(self.steps):
        Pipeline.logger.info(f"Running {step.__class__.__name__} - {i + 1} / {total}")
        data = step.execute(data)

    if self.save_path:
        data.save(self.save_path)

    return data

@classmethod
Expand All @@ -124,6 +130,10 @@ def from_json(cls, path: str) -> Pipeline:
Pipeline.load_and_register_custom_steps(custom_steps_path)

pipeline = Pipeline()

pipeline.load_path = config.get("load_path")
pipeline.save_path = config.get("save_path")

steps = []

for step_config in config["pipeline"]["steps"]:
Expand Down
7 changes: 4 additions & 3 deletions pipeline_lib/core/steps/calculate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ def __init__(self, config: Optional[dict] = None) -> None:
def execute(self, data: DataContainer) -> DataContainer:
self.logger.debug("Starting metric calculation")
model_output = data[DataContainer.MODEL_OUTPUT]
model_configs = data[DataContainer.MODEL_CONFIGS]
target_column_name = model_configs.get("target")

target_column_name = data.get(DataContainer.TARGET)

if target_column_name is None:
raise ValueError("Target column not found in model_configs.")
raise ValueError("Target column not found on any configuration.")

true_values = model_output[target_column_name]
predictions = model_output[DataContainer.PREDICTIONS]
Expand All @@ -32,5 +32,6 @@ def execute(self, data: DataContainer) -> DataContainer:
rmse = np.sqrt(mean_squared_error(true_values, predictions))

results = {"MAE": str(mae), "RMSE": str(rmse)}
self.logger.info(results)
data[DataContainer.METRICS] = results
return data
2 changes: 2 additions & 0 deletions pipeline_lib/implementation/tabular/xgboost/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def execute(self, data: DataContainer) -> DataContainer:
if target is None:
raise ValueError("Target column not found in model_configs.")

data[DataContainer.TARGET] = target

df_train = data[DataContainer.TRAIN]
df_valid = data[DataContainer.VALIDATION]

Expand Down
9 changes: 5 additions & 4 deletions pipeline_lib/implementation/tabular/xgboost/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ def execute(self, data: DataContainer) -> DataContainer:
if not isinstance(model_input, pd.DataFrame):
raise ValueError("model_input must be a pandas DataFrame.")

model_configs = data.get(DataContainer.MODEL_CONFIGS)
if model_configs:
drop_columns: list[str] = model_configs.get("drop_columns")
if self.config:
drop_columns = self.config.get("drop_columns")
if drop_columns:
model_input = model_input.drop(columns=drop_columns)
target = model_configs.get("target")

target = self.config.get("target")
if target is None:
raise ValueError("Target column not found in model_configs.")
data[DataContainer.TARGET] = target

predictions = model.predict(model_input.drop(columns=[target]))
else:
Expand Down

0 comments on commit 3e63edf

Please sign in to comment.