Skip to content

Commit

Permalink
save xgboost model to joblib
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Mar 6, 2024
1 parent a00981f commit e4ddec5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
2 changes: 0 additions & 2 deletions pipeline_lib/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class Pipeline:

def __init__(self, initial_data: Optional[DataContainer] = None):
self.steps = []
if not all(isinstance(step, PipelineStep) for step in self.steps):
raise TypeError("All steps must be instances of PipelineStep")
self.initial_data = initial_data
self.save_path = None
self.load_path = None
Expand Down
10 changes: 10 additions & 0 deletions pipeline_lib/implementation/tabular/xgboost/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import optuna
import xgboost as xgb
from joblib import dump
from optuna.pruners import MedianPruner
from sklearn.metrics import mean_absolute_error

Expand Down Expand Up @@ -79,6 +80,15 @@ def execute(self, data: DataContainer) -> DataContainer:
importance = dict(sorted(importance.items(), key=lambda item: item[1], reverse=True))
data[DataContainer.IMPORTANCE] = importance

# save model to disk
save_path = model_configs.get("save_path")

if save_path:
if not save_path.endswith(".joblib"):
raise ValueError("Only joblib format is supported for saving the model.")
self.logger.info(f"Saving the model to {save_path}")
dump(model, save_path)

end_time = time.time()
elapsed_time = end_time - start_time
minutes = int(elapsed_time // 60)
Expand Down
14 changes: 11 additions & 3 deletions pipeline_lib/implementation/tabular/xgboost/predict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pandas as pd
from joblib import load

from pipeline_lib.core import DataContainer
from pipeline_lib.core.steps import PredictStep
Expand All @@ -10,9 +11,16 @@ class XGBoostPredictStep(PredictStep):
def execute(self, data: DataContainer) -> DataContainer:
self.logger.debug("Obtaining predictions for XGBoost model.")

model = data[DataContainer.MODEL]
if model is None:
raise Exception("Model not trained yet.")
if not self.config:
raise ValueError("No prediction configs found.")

load_path = self.config.get("load_path")
if not load_path:
raise ValueError("No load path found in model_configs.")

if not load_path.endswith(".joblib"):
raise ValueError("Only joblib format is supported for loading the model.")
model = load(load_path)

model_input = data[DataContainer.CLEAN]

Expand Down

0 comments on commit e4ddec5

Please sign in to comment.