diff --git a/pipeline_lib/core/pipeline.py b/pipeline_lib/core/pipeline.py index 5b4e63c..258890e 100644 --- a/pipeline_lib/core/pipeline.py +++ b/pipeline_lib/core/pipeline.py @@ -19,8 +19,6 @@ class Pipeline: def __init__(self, initial_data: Optional[DataContainer] = None): self.steps = [] - if not all(isinstance(step, PipelineStep) for step in self.steps): - raise TypeError("All steps must be instances of PipelineStep") self.initial_data = initial_data self.save_path = None self.load_path = None diff --git a/pipeline_lib/implementation/tabular/xgboost/fit_model.py b/pipeline_lib/implementation/tabular/xgboost/fit_model.py index c2c834c..606e534 100644 --- a/pipeline_lib/implementation/tabular/xgboost/fit_model.py +++ b/pipeline_lib/implementation/tabular/xgboost/fit_model.py @@ -2,6 +2,7 @@ import optuna import xgboost as xgb +from joblib import dump from optuna.pruners import MedianPruner from sklearn.metrics import mean_absolute_error @@ -79,6 +80,15 @@ def execute(self, data: DataContainer) -> DataContainer: importance = dict(sorted(importance.items(), key=lambda item: item[1], reverse=True)) data[DataContainer.IMPORTANCE] = importance + # save model to disk + save_path = model_configs.get("save_path") + + if save_path: + if not save_path.endswith(".joblib"): + raise ValueError("Only joblib format is supported for saving the model.") + self.logger.info(f"Saving the model to {save_path}") + dump(model, save_path) + end_time = time.time() elapsed_time = end_time - start_time minutes = int(elapsed_time // 60) diff --git a/pipeline_lib/implementation/tabular/xgboost/predict.py b/pipeline_lib/implementation/tabular/xgboost/predict.py index f82323e..b0cd492 100644 --- a/pipeline_lib/implementation/tabular/xgboost/predict.py +++ b/pipeline_lib/implementation/tabular/xgboost/predict.py @@ -1,4 +1,5 @@ import pandas as pd +from joblib import load from pipeline_lib.core import DataContainer from pipeline_lib.core.steps import PredictStep @@ -10,9 +11,16 @@ class XGBoostPredictStep(PredictStep): def execute(self, data: DataContainer) -> DataContainer: self.logger.debug("Obtaining predictions for XGBoost model.") - model = data[DataContainer.MODEL] - if model is None: - raise Exception("Model not trained yet.") + if not self.config: + raise ValueError("No prediction configs found.") + + load_path = self.config.get("load_path") + if not load_path: + raise ValueError("No load path found in model_configs.") + + if not load_path.endswith(".joblib"): + raise ValueError("Only joblib format is supported for loading the model.") + model = load(load_path) model_input = data[DataContainer.CLEAN]