Skip to content

Commit b447fc9

Browse files
committed
Change the `config` dict argument in `__init__` to explicit `**params` keyword arguments
1 parent 871bbec commit b447fc9

File tree

5 files changed

+93
-104
lines changed

5 files changed

+93
-104
lines changed

pipeline_lib/core/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def from_json(cls, path: str) -> Pipeline:
143143
)
144144

145145
step_class = Pipeline.get_step_class(step_type)
146-
step = step_class(config=parameters)
146+
step = step_class(**parameters)
147147
steps.append(step)
148148

149149
pipeline.add_steps(steps)

pipeline_lib/core/steps/calculate_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
class CalculateMetricsStep(PipelineStep):
1212
"""Calculate metrics."""
1313

14-
def __init__(self, config: Optional[dict] = None) -> None:
14+
def __init__(self) -> None:
1515
"""Initialize CalculateMetricsStep."""
16-
super().__init__(config=config)
16+
super().__init__()
1717
self.init_logger()
1818

1919
def execute(self, data: DataContainer) -> DataContainer:

pipeline_lib/core/steps/tabular_split.py

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,19 @@
1111
class TabularSplitStep(PipelineStep):
1212
"""Split the data."""
1313

14-
def __init__(self, config: Optional[dict] = None) -> None:
14+
def __init__(
15+
self,
16+
train_percentage: float,
17+
id_column: str,
18+
train_ids: Optional[list[str]] = None,
19+
validation_ids: Optional[list[str]] = None,
20+
) -> None:
1521
"""Initialize SplitStep."""
16-
super().__init__(config=config)
1722
self.init_logger()
23+
self.train_percentage = train_percentage
24+
self.id_column_name = id_column
25+
self.train_ids = train_ids
26+
self.validation_ids = validation_ids
1827

1928
def _id_based_split(
2029
self,
@@ -76,50 +85,24 @@ def execute(self, data: DataContainer) -> DataContainer:
7685
"""Execute the split based on IDs."""
7786
self.logger.info("Splitting tabular data...")
7887

79-
split_configs = self.config
80-
81-
if split_configs is None:
82-
self.logger.info("No split_configs found. No splitting will be performed.")
83-
return data
84-
8588
df = data[DataContainer.CLEAN]
86-
id_column_name = split_configs.get("id_column")
87-
if not id_column_name:
88-
raise ValueError("ID column name must be specified in split_configs.")
89-
90-
# check if both train_percentage and train_ids are provided
91-
if "train_percentage" in split_configs and "train_ids" in split_configs:
92-
raise ValueError(
93-
"Both train_percentage and train_ids cannot be provided in split_configs."
94-
)
95-
96-
# check if either train_percentage or train_ids are provided
97-
if "train_percentage" not in split_configs and "train_ids" not in split_configs:
98-
raise ValueError(
99-
"Either train_percentage or train_ids must be provided in split_configs."
100-
)
10189

102-
if "train_percentage" in split_configs:
103-
train_percentage = split_configs.get("train_percentage")
104-
if train_percentage is None or train_percentage <= 0 or train_percentage >= 1:
90+
if self.train_percentage:
91+
if (
92+
self.train_percentage is None
93+
or self.train_percentage <= 0
94+
or self.train_percentage >= 1
95+
):
10596
raise ValueError("train_percentage must be between 0 and 1.")
10697
train_ids, validation_ids = self._percentage_based_id_split(
107-
df, train_percentage, id_column_name
98+
df, self.train_percentage, self.id_column_name
10899
)
109-
else:
110-
train_ids = split_configs.get("train_ids")
111-
validation_ids = split_configs.get("validation_ids")
112-
if not train_ids or not validation_ids:
113-
raise ValueError(
114-
"Both train_ids and validation_ids must be provided in split_configs unless"
115-
" train_percentage is specified."
116-
)
117100

118101
self.logger.info(f"Number of train ids: {len(train_ids)}")
119102
self.logger.info(f"Number of validation ids: {len(validation_ids)}")
120103

121104
train_df, validation_df = self._id_based_split(
122-
df, train_ids, validation_ids, id_column_name
105+
df, train_ids, validation_ids, self.id_column_name
123106
)
124107

125108
train_rows = len(train_df)
@@ -134,7 +117,9 @@ def execute(self, data: DataContainer) -> DataContainer:
134117
f" {validation_rows/total_rows:.2%}"
135118
)
136119

137-
left_ids = df[~df[id_column_name].isin(train_ids + validation_ids)][id_column_name].unique()
120+
left_ids = df[~df[self.id_column_name].isin(train_ids + validation_ids)][
121+
self.id_column_name
122+
].unique()
138123
self.logger.info(f"Number of IDs left from total df: {len(left_ids)}")
139124
self.logger.debug(f"IDs left from total df: {left_ids}")
140125

pipeline_lib/implementation/tabular/xgboost/fit_model.py

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,61 +6,70 @@
66
from optuna.pruners import MedianPruner
77
from sklearn.metrics import mean_absolute_error
88

9+
from typing import Optional
10+
911
from pipeline_lib.core import DataContainer
1012
from pipeline_lib.core.steps import FitModelStep
1113

1214

1315
class XGBoostFitModelStep(FitModelStep):
1416
"""Fit the model with XGBoost."""
1517

16-
def execute(self, data: DataContainer) -> DataContainer:
17-
self.logger.debug("Starting model fitting with XGBoost")
18+
def __init__(
19+
self,
20+
target: str,
21+
drop_columns: Optional[list[str]] = None,
22+
xgb_params: Optional[dict] = None,
23+
optuna_params: Optional[dict] = None,
24+
save_path: Optional[str] = None,
25+
) -> None:
26+
self.init_logger()
1827

19-
start_time = time.time()
28+
if target is None:
29+
raise ValueError("Target column not found in the parameters.")
2030

21-
model_configs = self.config
31+
self.target = target
32+
self.drop_columns = drop_columns
2233

23-
if model_configs is None:
24-
raise ValueError("No model configs found")
34+
if optuna_params and xgb_params:
35+
raise ValueError("Both optuna_params and xgb_params are defined. Please choose one.")
2536

26-
target = model_configs.get("target")
37+
if not optuna_params and not xgb_params:
38+
raise ValueError(
39+
"No parameters defined. Please define either optuna_params or xgb_params."
40+
)
2741

28-
if target is None:
29-
raise ValueError("Target column not found in model_configs.")
42+
self.xgb_params = xgb_params
43+
self.optuna_params = optuna_params
44+
self.save_path = save_path
3045

31-
data[DataContainer.TARGET] = target
46+
def execute(self, data: DataContainer) -> DataContainer:
47+
self.logger.debug("Starting model fitting with XGBoost")
48+
49+
start_time = time.time()
50+
51+
data[DataContainer.TARGET] = self.target
3252

3353
df_train = data[DataContainer.TRAIN]
3454
df_valid = data[DataContainer.VALIDATION]
3555

36-
drop_columns = model_configs.get("drop_columns")
37-
38-
if drop_columns:
39-
df_train = df_train.drop(columns=drop_columns)
40-
df_valid = df_valid.drop(columns=drop_columns)
56+
if self.drop_columns:
57+
df_train = df_train.drop(columns=self.drop_columns)
58+
df_valid = df_valid.drop(columns=self.drop_columns)
4159

4260
# Prepare the data
43-
X_train = df_train.drop(columns=[target])
44-
y_train = df_train[target]
61+
X_train = df_train.drop(columns=[self.target])
62+
y_train = df_train[self.target]
4563

46-
X_valid = df_valid.drop(columns=[target])
47-
y_valid = df_valid[target]
64+
X_valid = df_valid.drop(columns=[self.target])
65+
y_valid = df_valid[self.target]
4866

49-
optuna_params = model_configs.get("optuna_params")
50-
xgb_params = model_configs.get("xgb_params")
67+
params = self.xgb_params
5168

52-
if optuna_params and xgb_params:
53-
raise ValueError("Both optuna_params and xgb_params are defined. Please choose one.")
54-
55-
if not optuna_params and not xgb_params:
56-
raise ValueError(
57-
"No parameters defined. Please define either optuna_params or xgb_params."
69+
if self.optuna_params:
70+
params = self.optimize_with_optuna(
71+
X_train, y_train, X_valid, y_valid, self.optuna_params
5872
)
59-
60-
params = xgb_params
61-
62-
if optuna_params:
63-
params = self.optimize_with_optuna(X_train, y_train, X_valid, y_valid, optuna_params)
6473
data[DataContainer.TUNING_PARAMS] = params
6574

6675
model = xgb.XGBRegressor(**params)
@@ -69,10 +78,15 @@ def execute(self, data: DataContainer) -> DataContainer:
6978
X_train,
7079
y_train,
7180
eval_set=[(X_valid, y_valid)],
72-
early_stopping_rounds=model_configs.get("early_stopping_rounds", 100),
7381
verbose=True,
7482
)
7583

84+
end_time = time.time()
85+
elapsed_time = end_time - start_time
86+
minutes = int(elapsed_time // 60)
87+
seconds = int(elapsed_time % 60)
88+
self.logger.info(f"XGBoost model fitting took {minutes} minutes and {seconds} seconds.")
89+
7690
# Save the model to the data container
7791
data[DataContainer.MODEL] = model
7892

@@ -81,19 +95,14 @@ def execute(self, data: DataContainer) -> DataContainer:
8195
data[DataContainer.IMPORTANCE] = importance
8296

8397
# save model to disk
84-
save_path = model_configs.get("save_path")
98+
save_path = self.save_path
8599

86100
if save_path:
87101
if not save_path.endswith(".joblib"):
88102
raise ValueError("Only joblib format is supported for saving the model.")
89103
self.logger.info(f"Saving the model to {save_path}")
90104
dump(model, save_path)
91105

92-
end_time = time.time()
93-
elapsed_time = end_time - start_time
94-
minutes = int(elapsed_time // 60)
95-
seconds = int(elapsed_time % 60)
96-
self.logger.info(f"XGBoost model fitting took {minutes} minutes and {seconds} seconds.")
97106
return data
98107

99108
def optimize_with_optuna(self, X_train, y_train, X_valid, y_valid, optuna_params):

pipeline_lib/implementation/tabular/xgboost/predict.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,42 @@
33

44
from pipeline_lib.core import DataContainer
55
from pipeline_lib.core.steps import PredictStep
6+
from typing import Optional
67

78

89
class XGBoostPredictStep(PredictStep):
910
"""Obtain the predictions for XGBoost model."""
1011

11-
def execute(self, data: DataContainer) -> DataContainer:
12-
self.logger.debug("Obtaining predictions for XGBoost model.")
13-
14-
if not self.config:
15-
raise ValueError("No prediction configs found.")
16-
17-
load_path = self.config.get("load_path")
18-
if not load_path:
19-
raise ValueError("No load path found in model_configs.")
12+
def __init__(
13+
self,
14+
target: str,
15+
load_path: str,
16+
drop_columns: Optional[list[str]] = None,
17+
) -> None:
18+
self.init_logger()
2019

2120
if not load_path.endswith(".joblib"):
2221
raise ValueError("Only joblib format is supported for loading the model.")
23-
model = load(load_path)
2422

25-
model_input = data[DataContainer.CLEAN]
23+
self.target = target
24+
self.load_path = load_path
25+
self.drop_columns = drop_columns
2626

27-
if not isinstance(model_input, pd.DataFrame):
28-
raise ValueError("model_input must be a pandas DataFrame.")
27+
self.model = load(self.load_path)
2928

30-
if self.config:
31-
drop_columns = self.config.get("drop_columns")
32-
if drop_columns:
33-
model_input = model_input.drop(columns=drop_columns)
29+
def execute(self, data: DataContainer) -> DataContainer:
30+
self.logger.debug("Obtaining predictions for XGBoost model.")
3431

35-
target = self.config.get("target")
36-
if target is None:
37-
raise ValueError("Target column not found in model_configs.")
38-
data[DataContainer.TARGET] = target
32+
model_input = data[DataContainer.CLEAN]
3933

40-
predictions = model.predict(model_input.drop(columns=[target]))
41-
else:
42-
predictions = model.predict(model_input)
34+
if self.drop_columns:
35+
model_input = model_input.drop(columns=self.drop_columns)
36+
37+
predictions = self.model.predict(model_input.drop(columns=[self.target]))
4338

4439
predictions_df = pd.DataFrame(predictions, columns=["prediction"])
4540

4641
model_input[DataContainer.PREDICTIONS] = predictions_df
47-
4842
data[DataContainer.MODEL_OUTPUT] = model_input
43+
data[DataContainer.TARGET] = self.target
4944
return data

0 commit comments

Comments (0)