-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4808ab6
commit 2c96b4b
Showing
6 changed files
with
185 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from .core.pipeline import Pipeline | ||
|
||
Pipeline.step_registry.auto_register_steps_from_package("pipeline_lib.core.steps") | ||
Pipeline.step_registry.auto_register_steps_from_package( | ||
Pipeline.model_registry.auto_register_models_from_package( | ||
"pipeline_lib.implementation.tabular.xgboost" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List, Optional, Tuple | ||
|
||
import pandas as pd | ||
|
||
|
||
class Model(ABC):
    """Abstract interface that concrete model implementations must satisfy.

    Both :meth:`fit` and :meth:`predict` are abstract, so a subclass cannot be
    instantiated until it overrides the two of them.
    """

    @abstractmethod
    def fit(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        eval_set: Optional[List[Tuple[pd.DataFrame, pd.Series]]] = None,
        verbose: Optional[bool] = True,
    ):
        """Fit the model.

        Args:
            X: Training features.
            y: Training target.
            eval_set: Optional list of ``(features, target)`` pairs; how they
                are used is left to the implementation.
            verbose: Verbosity flag handed to the implementation. Defaults to
                ``True``.
        """

    @abstractmethod
    def predict(self, X: pd.DataFrame) -> pd.Series:
        """Produce predictions for the rows of ``X``.

        Args:
            X: Features to predict on.

        Returns:
            The predictions, declared as a ``pd.Series``.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import importlib | ||
import logging | ||
import pkgutil | ||
|
||
from pipeline_lib.core.model import Model | ||
|
||
|
||
class ModelClassNotFoundError(Exception):
    """Raised when a model name is looked up but is not in the registry."""
|
||
|
||
class ModelRegistry:
    """Registry mapping a model class name to the ``Model`` subclass itself."""

    def __init__(self):
        # Keys are ``cls.__name__``; registering a second class with the same
        # name overwrites the earlier entry.
        self._model_registry = {}
        self.logger = logging.getLogger(__name__)

    def register_model(self, model_class: type):
        """Add ``model_class`` to the registry under its ``__name__``.

        Args:
            model_class: The class to register.

        Raises:
            ValueError: If ``model_class`` is not a class, or is a class that
                does not derive from ``Model``.
        """
        # Validate before reading ``__name__``: a non-class argument would
        # otherwise fail with AttributeError/TypeError instead of ValueError.
        if not isinstance(model_class, type) or not issubclass(model_class, Model):
            raise ValueError(f"{model_class} must be a subclass of Model")
        self._model_registry[model_class.__name__] = model_class

    def get_model_class(self, model_name: str) -> type:
        """Return the class registered under ``model_name``.

        Raises:
            ModelClassNotFoundError: If no class is registered under that name.
        """
        try:
            return self._model_registry[model_name]
        except KeyError:
            raise ModelClassNotFoundError(
                f"Model class '{model_name}' not found in registry."
            ) from None

    def get_all_model_classes(self) -> dict:
        """Return the name-to-class mapping.

        This is the registry's own dict, not a copy, so mutations made by the
        caller are visible to the registry.
        """
        return self._model_registry

    def auto_register_models_from_package(self, package_name: str):
        """Import ``package_name`` and register every ``Model`` subclass found.

        The named module is scanned first; if it is a package, each module
        reported by ``pkgutil.walk_packages`` is then imported and scanned.
        An ``ImportError`` is logged and stops the scan; classes registered
        before the failure stay registered.

        Args:
            package_name: Dotted import path of the package (or plain module).
        """
        try:
            package = importlib.import_module(package_name)
            self._register_models_from_module(package)

            # A plain module has no ``__path__`` and therefore no submodules.
            search_path = getattr(package, "__path__", None)
            if search_path is None:
                return

            prefix = package.__name__ + "."
            for module_info in pkgutil.walk_packages(search_path, prefix):
                submodule = importlib.import_module(module_info.name)
                self._register_models_from_module(submodule)
        except ImportError as e:
            self.logger.error("Failed to import package: %s. Error: %s", package_name, e)

    def _register_models_from_module(self, module):
        """Register each ``Model`` subclass (except ``Model``) exposed by ``module``."""
        for name in dir(module):
            candidate = getattr(module, name)
            if (
                isinstance(candidate, type)
                and issubclass(candidate, Model)
                and candidate is not Model
            ):
                self.register_model(candidate)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import Any | ||
|
||
import pandas as pd | ||
import xgboost as xgb | ||
|
||
from pipeline_lib.core.model import Model | ||
|
||
|
||
class XGBoostModel(Model):
    """``Model`` implementation that delegates to ``xgb.XGBRegressor``."""

    def __init__(self, **params):
        """Build the wrapped regressor.

        Args:
            **params: Keyword arguments passed unchanged to
                ``xgb.XGBRegressor``.
        """
        self.model = xgb.XGBRegressor(**params)

    def fit(
        self, X: pd.DataFrame, y: pd.Series, eval_set=None, verbose=True
    ) -> "XGBoostModel":
        """Fit the wrapped regressor and return this instance.

        Args:
            X: Training features.
            y: Training target.
            eval_set: Forwarded as-is to the regressor's ``fit``.
            verbose: Forwarded as-is to the regressor's ``fit``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.model.fit(X, y, eval_set=eval_set, verbose=verbose)
        return self

    def predict(self, X: pd.DataFrame) -> pd.Series:
        """Return the wrapped regressor's predictions for ``X``."""
        # NOTE(review): the regressor's result is returned untouched; it is
        # presumably a NumPy array rather than the annotated pd.Series --
        # confirm against callers before relying on Series-only behaviour.
        return self.model.predict(X)