Skip to content

Commit 14e05b9

Browse files
committed
refactor step registry
1 parent fef94d9 commit 14e05b9

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

pipeline_lib/core/step_registry.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,54 @@
66
from pipeline_lib.core.steps import PipelineStep
77

88

9+
class StepClassNotFoundError(Exception):
10+
pass
11+
12+
913
class StepRegistry:
1014
"""A helper class for managing the registry of pipeline steps."""
1115

1216
def __init__(self):
1317
self._step_registry = {}
14-
self.logger = logging.getLogger(StepRegistry.__name__)
18+
self.logger = logging.getLogger(__name__)
1519

16-
def register_step(self, step_class):
20+
def register_step(self, step_class: type):
1721
"""Register a step class using its class name."""
1822
step_name = step_class.__name__
1923
if not issubclass(step_class, PipelineStep):
2024
raise ValueError(f"{step_class} must be a subclass of PipelineStep")
2125
self._step_registry[step_name] = step_class
2226

23-
def get_step_class(self, step_name):
27+
def get_step_class(self, step_name: str) -> type:
2428
"""Retrieve a step class by name."""
2529
if step_name in self._step_registry:
2630
return self._step_registry[step_name]
2731
else:
28-
raise ValueError(f"Step class '{step_name}' not found in registry.")
32+
raise StepClassNotFoundError(f"Step class '{step_name}' not found in registry.")
33+
34+
def get_all_step_classes(self) -> dict:
35+
"""Retrieve all registered step classes."""
36+
return self._step_registry
2937

30-
def auto_register_steps_from_package(self, package_name):
38+
def auto_register_steps_from_package(self, package_name: str):
3139
"""
3240
Automatically registers all step classes found within a specified package.
3341
"""
34-
package = importlib.import_module(package_name)
35-
prefix = package.__name__ + "."
36-
for importer, modname, ispkg in pkgutil.walk_packages(package.__path__, prefix):
37-
module = importlib.import_module(modname)
38-
for name in dir(module):
39-
attribute = getattr(module, name)
40-
if (
41-
isinstance(attribute, type)
42-
and issubclass(attribute, PipelineStep)
43-
and attribute is not PipelineStep
44-
):
45-
self.register_step(attribute)
42+
try:
43+
package = importlib.import_module(package_name)
44+
prefix = package.__name__ + "."
45+
for importer, modname, ispkg in pkgutil.walk_packages(package.__path__, prefix):
46+
module = importlib.import_module(modname)
47+
for name in dir(module):
48+
attribute = getattr(module, name)
49+
if (
50+
isinstance(attribute, type)
51+
and issubclass(attribute, PipelineStep)
52+
and attribute is not PipelineStep
53+
):
54+
self.register_step(attribute)
55+
except ImportError as e:
56+
self.logger.error(f"Failed to import package: {package_name}. Error: {e}")
4657

4758
def load_and_register_custom_steps(self, custom_steps_path: str) -> None:
4859
"""

0 commit comments

Comments
 (0)