Skip to content

Commit fef94d9

Browse files
committed
split step_registry to new class
1 parent b447fc9 commit fef94d9

File tree

7 files changed

+92
-177
lines changed

7 files changed

+92
-177
lines changed

pipeline_lib/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from .core.pipeline import Pipeline

# Populate the shared, class-level step registry at import time so that all
# built-in steps are available as soon as the package is imported.
Pipeline.step_registry.auto_register_steps_from_package("pipeline_lib.core.steps")
Pipeline.step_registry.auto_register_steps_from_package(
    "pipeline_lib.implementation.tabular.xgboost"
)

pipeline_lib/core/pipeline.py

Lines changed: 4 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from __future__ import annotations
22

3-
import importlib
43
import json
54
import logging
6-
import os
7-
import pkgutil
85
from typing import Optional
96

107
from pipeline_lib.core.data_container import DataContainer
8+
from pipeline_lib.core.step_registry import StepRegistry
119
from pipeline_lib.core.steps import PipelineStep
1210

1311

@@ -16,89 +14,18 @@ class Pipeline:
1614

1715
    # NOTE(review): legacy class-level registry dict — no longer read or
    # written after the StepRegistry refactor; consider removing.
    _step_registry = {}
    logger = logging.getLogger("Pipeline")
    # Shared registry instance; all registration/lookup now goes through it.
    step_registry = StepRegistry()
1918

2019
    def __init__(self, initial_data: Optional[DataContainer] = None):
        """Create an empty pipeline, optionally seeded with initial data.

        Parameters
        ----------
        initial_data : Optional[DataContainer]
            Data container the pipeline starts from; may be ``None``.
        """
        self.steps = []  # ordered list of PipelineStep instances to execute
        self.initial_data = initial_data
        # Not set here — presumably configured elsewhere before use
        # (TODO confirm against callers such as from_json).
        self.save_path = None
        self.load_path = None
2524

26-
@classmethod
27-
def register_step(cls, step_class):
28-
"""Register a step class using its class name."""
29-
step_name = step_class.__name__
30-
if not issubclass(step_class, PipelineStep):
31-
raise ValueError(f"{step_class} must be a subclass of PipelineStep")
32-
cls._step_registry[step_name] = step_class
33-
34-
@classmethod
35-
def get_step_class(cls, step_name):
36-
"""Retrieve a step class by name."""
37-
if step_name in cls._step_registry:
38-
return cls._step_registry[step_name]
39-
else:
40-
raise ValueError(f"Step class '{step_name}' not found in registry.")
41-
4225
def add_steps(self, steps: list[PipelineStep]):
4326
"""Add steps to the pipeline."""
4427
self.steps.extend(steps)
4528

46-
@classmethod
47-
def auto_register_steps_from_package(cls, package_name):
48-
"""
49-
Automatically registers all step classes found within a specified package.
50-
"""
51-
package = importlib.import_module(package_name)
52-
prefix = package.__name__ + "."
53-
for importer, modname, ispkg in pkgutil.walk_packages(package.__path__, prefix):
54-
module = importlib.import_module(modname)
55-
for name in dir(module):
56-
attribute = getattr(module, name)
57-
if (
58-
isinstance(attribute, type)
59-
and issubclass(attribute, PipelineStep)
60-
and attribute is not PipelineStep
61-
):
62-
cls.register_step(attribute)
63-
64-
@classmethod
65-
def load_and_register_custom_steps(cls, custom_steps_path: str) -> None:
66-
"""
67-
Dynamically loads and registers step classes found in the specified directory.
68-
69-
This method scans a specified directory for Python files (excluding __init__.py),
70-
dynamically imports these files as modules, and registers all classes derived from
71-
PipelineStep found within these modules.
72-
73-
Parameters
74-
----------
75-
custom_steps_path : str
76-
The path to the directory containing custom step implementation files.
77-
"""
78-
Pipeline.logger.debug(f"Loading custom steps from: {custom_steps_path}")
79-
for filename in os.listdir(custom_steps_path):
80-
if filename.endswith(".py") and not filename.startswith("__"):
81-
filepath = os.path.join(custom_steps_path, filename)
82-
module_name = os.path.splitext(filename)[0]
83-
spec = importlib.util.spec_from_file_location(module_name, filepath)
84-
module = importlib.util.module_from_spec(spec)
85-
86-
try:
87-
spec.loader.exec_module(module)
88-
Pipeline.logger.debug(f"Successfully loaded module: {module_name}")
89-
90-
for attribute_name in dir(module):
91-
attribute = getattr(module, attribute_name)
92-
if (
93-
isinstance(attribute, type)
94-
and issubclass(attribute, PipelineStep)
95-
and attribute is not PipelineStep
96-
):
97-
Pipeline.register_step(attribute)
98-
Pipeline.logger.debug(f"Registered step class: {attribute_name}")
99-
except Exception as e:
100-
Pipeline.logger.error(f"Failed to load module: {module_name}. Error: {e}")
101-
10229
def run(self) -> DataContainer:
10330
"""Run the pipeline on the given data."""
10431

@@ -125,7 +52,7 @@ def from_json(cls, path: str) -> Pipeline:
12552

12653
custom_steps_path = config.get("custom_steps_path")
12754
if custom_steps_path:
128-
Pipeline.load_and_register_custom_steps(custom_steps_path)
55+
cls.step_registry.load_and_register_custom_steps(custom_steps_path)
12956

13057
pipeline = Pipeline()
13158

@@ -142,7 +69,7 @@ def from_json(cls, path: str) -> Pipeline:
14269
f"Creating step {step_type} with parameters: \n {json.dumps(parameters, indent=4)}"
14370
)
14471

145-
step_class = Pipeline.get_step_class(step_type)
72+
step_class = cls.step_registry.get_step_class(step_type)
14673
step = step_class(**parameters)
14774
steps.append(step)
14875

pipeline_lib/core/step_registry.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
import importlib
# Explicit submodule import: spec_from_file_location / module_from_spec live
# in importlib.util, which `import importlib` alone does not guarantee to
# load (it previously worked only via pkgutil's transitive import).
import importlib.util
import logging
import os
import pkgutil

from pipeline_lib.core.steps import PipelineStep


class StepRegistry:
    """A helper class for managing the registry of pipeline steps."""

    def __init__(self):
        # Maps step class name -> step class.
        self._step_registry = {}
        self.logger = logging.getLogger(StepRegistry.__name__)

    def register_step(self, step_class):
        """Register a step class using its class name.

        Raises
        ------
        ValueError
            If ``step_class`` is not a subclass of ``PipelineStep``.
        """
        # Validate before touching the registry so an invalid class
        # never lands in it.
        if not issubclass(step_class, PipelineStep):
            raise ValueError(f"{step_class} must be a subclass of PipelineStep")
        self._step_registry[step_class.__name__] = step_class

    def get_step_class(self, step_name):
        """Retrieve a step class by name.

        Raises
        ------
        ValueError
            If no step with that name has been registered.
        """
        if step_name in self._step_registry:
            return self._step_registry[step_name]
        else:
            raise ValueError(f"Step class '{step_name}' not found in registry.")

    def _register_steps_in_module(self, module):
        """Register every concrete PipelineStep subclass defined in *module*."""
        for attribute_name in dir(module):
            attribute = getattr(module, attribute_name)
            if (
                isinstance(attribute, type)
                and issubclass(attribute, PipelineStep)
                and attribute is not PipelineStep
            ):
                self.register_step(attribute)
                self.logger.debug(f"Registered step class: {attribute_name}")

    def auto_register_steps_from_package(self, package_name):
        """
        Automatically registers all step classes found within a specified package.
        """
        package = importlib.import_module(package_name)
        prefix = package.__name__ + "."
        # walk_packages recurses into sub-packages; importing each module
        # triggers class definitions that we then scan for steps.
        for _importer, modname, _ispkg in pkgutil.walk_packages(package.__path__, prefix):
            module = importlib.import_module(modname)
            self._register_steps_in_module(module)

    def load_and_register_custom_steps(self, custom_steps_path: str) -> None:
        """
        Dynamically loads and registers step classes found in the specified directory.

        Scans *custom_steps_path* for Python files (excluding dunder files),
        imports each as a module, and registers all PipelineStep subclasses
        found within. Load failures are logged, not raised.
        """
        self.logger.debug(f"Loading custom steps from: {custom_steps_path}")
        for filename in os.listdir(custom_steps_path):
            if not filename.endswith(".py") or filename.startswith("__"):
                continue
            filepath = os.path.join(custom_steps_path, filename)
            module_name = os.path.splitext(filename)[0]
            spec = importlib.util.spec_from_file_location(module_name, filepath)
            # spec (or its loader) can be None for unloadable files; guard
            # instead of crashing with AttributeError below.
            if spec is None or spec.loader is None:
                self.logger.error(
                    f"Failed to load module: {module_name}. Error: no import spec"
                )
                continue
            module = importlib.util.module_from_spec(spec)

            try:
                spec.loader.exec_module(module)
                self.logger.debug(f"Successfully loaded module: {module_name}")
                self._register_steps_in_module(module)
            except Exception as e:
                self.logger.error(f"Failed to load module: {module_name}. Error: {e}")

pipeline_lib/core/steps/calculate_metrics.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Optional
2-
31
import numpy as np
42
from sklearn.metrics import mean_absolute_error, mean_squared_error
53

pipeline_lib/core/steps/tabular_split.py

Lines changed: 8 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from typing import Optional, Tuple
2-
3-
import pandas as pd
41
from sklearn.model_selection import train_test_split
52

63
from pipeline_lib.core import DataContainer
@@ -11,118 +8,36 @@
118
class TabularSplitStep(PipelineStep):
129
"""Split the data."""
1310

14-
def __init__(
15-
self,
16-
train_percentage: float,
17-
id_column: str,
18-
train_ids: Optional[list[str]] = None,
19-
validation_ids: Optional[list[str]] = None,
20-
) -> None:
11+
def __init__(self, train_percentage: float) -> None:
2112
"""Initialize SplitStep."""
2213
self.init_logger()
2314
self.train_percentage = train_percentage
24-
self.id_column_name = id_column
25-
self.train_ids = train_ids
26-
self.validation_ids = validation_ids
27-
28-
def _id_based_split(
29-
self,
30-
df: pd.DataFrame,
31-
train_ids: list[str],
32-
validation_ids: list[str],
33-
id_column_name: str,
34-
) -> Tuple[pd.DataFrame, pd.DataFrame]:
35-
"""
36-
Splits the DataFrame into training and validation sets based on specified IDs.
37-
38-
Parameters
39-
----------
40-
df : pd.DataFrame
41-
The DataFrame to split.
42-
train_ids : List[str]
43-
List of IDs for the training set.
44-
validation_ids : List[str]
45-
List of IDs for the validation set.
46-
id_column_name : str
47-
The name of the column in df that contains the IDs.
48-
49-
Returns
50-
-------
51-
Tuple[pd.DataFrame, pd.DataFrame]
52-
A tuple containing the training set and the validation set.
53-
"""
54-
train_df = df[df[id_column_name].isin(train_ids)]
55-
validation_df = df[df[id_column_name].isin(validation_ids)]
56-
return train_df, validation_df
57-
58-
def _percentage_based_id_split(
59-
self, df: pd.DataFrame, train_percentage: float, id_column_name: str
60-
) -> Tuple[list[str], list[str]]:
61-
"""
62-
Splits the unique IDs into training and validation sets based on specified percentages.
6315

64-
Parameters
65-
----------
66-
df : pd.DataFrame
67-
The DataFrame containing the IDs.
68-
train_percentage : float
69-
The percentage of IDs to include in the training set.
70-
id_column_name : str
71-
The name of the column containing the IDs.
72-
73-
Returns
74-
-------
75-
Tuple[List[str], List[str]]
76-
A tuple containing lists of training and validation IDs.
77-
"""
78-
unique_ids = df[id_column_name].unique()
79-
train_ids, validation_ids = train_test_split(
80-
unique_ids, train_size=train_percentage, random_state=42
81-
)
82-
return train_ids.tolist(), validation_ids.tolist()
16+
if self.train_percentage <= 0 or self.train_percentage >= 1:
17+
raise ValueError("train_percentage must be between 0 and 1.")
8318

8419
def execute(self, data: DataContainer) -> DataContainer:
85-
"""Execute the split based on IDs."""
20+
"""Execute the random train-validation split."""
8621
self.logger.info("Splitting tabular data...")
8722

8823
df = data[DataContainer.CLEAN]
8924

90-
if self.train_percentage:
91-
if (
92-
self.train_percentage is None
93-
or self.train_percentage <= 0
94-
or self.train_percentage >= 1
95-
):
96-
raise ValueError("train_percentage must be between 0 and 1.")
97-
train_ids, validation_ids = self._percentage_based_id_split(
98-
df, self.train_percentage, self.id_column_name
99-
)
100-
101-
self.logger.info(f"Number of train ids: {len(train_ids)}")
102-
self.logger.info(f"Number of validation ids: {len(validation_ids)}")
103-
104-
train_df, validation_df = self._id_based_split(
105-
df, train_ids, validation_ids, self.id_column_name
25+
train_df, validation_df = train_test_split(
26+
df, train_size=self.train_percentage, random_state=42
10627
)
10728

10829
train_rows = len(train_df)
10930
validation_rows = len(validation_df)
11031
total_rows = train_rows + validation_rows
11132

11233
self.logger.info(
113-
f"Number of rows in training set: {len(train_df)} | {train_rows/total_rows:.2%}"
34+
f"Number of rows in training set: {train_rows} | {train_rows/total_rows:.2%}"
11435
)
11536
self.logger.info(
116-
f"Number of rows in validation set: {len(validation_df)} |"
37+
f"Number of rows in validation set: {validation_rows} |"
11738
f" {validation_rows/total_rows:.2%}"
11839
)
11940

120-
left_ids = df[~df[self.id_column_name].isin(train_ids + validation_ids)][
121-
self.id_column_name
122-
].unique()
123-
self.logger.info(f"Number of IDs left from total df: {len(left_ids)}")
124-
self.logger.debug(f"IDs left from total df: {left_ids}")
125-
12641
data[DataContainer.TRAIN] = train_df
12742
data[DataContainer.VALIDATION] = validation_df
12843

pipeline_lib/implementation/tabular/xgboost/fit_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import time
2+
from typing import Optional
23

34
import optuna
45
import xgboost as xgb
56
from joblib import dump
67
from optuna.pruners import MedianPruner
78
from sklearn.metrics import mean_absolute_error
89

9-
from typing import Optional
10-
1110
from pipeline_lib.core import DataContainer
1211
from pipeline_lib.core.steps import FitModelStep
1312

pipeline_lib/implementation/tabular/xgboost/predict.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from typing import Optional
2+
13
import pandas as pd
24
from joblib import load
35

46
from pipeline_lib.core import DataContainer
57
from pipeline_lib.core.steps import PredictStep
6-
from typing import Optional
78

89

910
class XGBoostPredictStep(PredictStep):

0 commit comments

Comments (0)