Skip to content

Commit

Permalink
refactor step registry
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Mar 13, 2024
1 parent fef94d9 commit 14e05b9
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions pipeline_lib/core/step_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,54 @@
from pipeline_lib.core.steps import PipelineStep


class StepClassNotFoundError(Exception):
pass


class StepRegistry:
"""A helper class for managing the registry of pipeline steps."""

def __init__(self):
self._step_registry = {}
self.logger = logging.getLogger(StepRegistry.__name__)
self.logger = logging.getLogger(__name__)

def register_step(self, step_class):
def register_step(self, step_class: type):
"""Register a step class using its class name."""
step_name = step_class.__name__
if not issubclass(step_class, PipelineStep):
raise ValueError(f"{step_class} must be a subclass of PipelineStep")
self._step_registry[step_name] = step_class

def get_step_class(self, step_name):
def get_step_class(self, step_name: str) -> type:
"""Retrieve a step class by name."""
if step_name in self._step_registry:
return self._step_registry[step_name]
else:
raise ValueError(f"Step class '{step_name}' not found in registry.")
raise StepClassNotFoundError(f"Step class '{step_name}' not found in registry.")

def get_all_step_classes(self) -> dict:
"""Retrieve all registered step classes."""
return self._step_registry

def auto_register_steps_from_package(self, package_name):
def auto_register_steps_from_package(self, package_name: str):
"""
Automatically registers all step classes found within a specified package.
"""
package = importlib.import_module(package_name)
prefix = package.__name__ + "."
for importer, modname, ispkg in pkgutil.walk_packages(package.__path__, prefix):
module = importlib.import_module(modname)
for name in dir(module):
attribute = getattr(module, name)
if (
isinstance(attribute, type)
and issubclass(attribute, PipelineStep)
and attribute is not PipelineStep
):
self.register_step(attribute)
try:
package = importlib.import_module(package_name)
prefix = package.__name__ + "."
for importer, modname, ispkg in pkgutil.walk_packages(package.__path__, prefix):
module = importlib.import_module(modname)
for name in dir(module):
attribute = getattr(module, name)
if (
isinstance(attribute, type)
and issubclass(attribute, PipelineStep)
and attribute is not PipelineStep
):
self.register_step(attribute)
except ImportError as e:
self.logger.error(f"Failed to import package: {package_name}. Error: {e}")

def load_and_register_custom_steps(self, custom_steps_path: str) -> None:
"""
Expand Down

0 comments on commit 14e05b9

Please sign in to comment.