Commit ad8e35c (1 parent: 6aab106)
Showing 30 changed files with 3,932 additions and 32 deletions.
Empty file.
pipeline_lib/core/__init__.py
@@ -0,0 +1,3 @@
from .data_container import DataContainer  # noqa: F401
from .pipeline import Pipeline  # noqa: F401
from .steps import PipelineStep  # noqa: F401
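Taken together, these re-exports expose the core classes at the package root. A minimal usage sketch, assuming the package is installed under the name pipeline_lib:

# The three classes defined in this commit become available in one import.
from pipeline_lib.core import DataContainer, Pipeline, PipelineStep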
pipeline_lib/core/data_container.py
@@ -0,0 +1,287 @@
"""DataContainer class for storing data used in pipeline processing.""" | ||
|
||
from __future__ import annotations | ||
|
||
import json | ||
import logging | ||
import pickle | ||
import sys | ||
from typing import Optional, Union | ||
|
||
|
||
class DataContainer: | ||
""" | ||
A container for storing and manipulating data in a pipeline. | ||
Attributes | ||
---------- | ||
data : dict | ||
A dictionary to store data items. | ||
""" | ||
|
||
GENERATE_CONFIGS = "generate_configs" | ||
CLEAN_CONFIGS = "clean_configs" | ||
SPLIT_CONFIGS = "split_configs" | ||
TARGET_SCALING_CONFIGS = "target_scaling_configs" | ||
RAW = "raw" | ||
CLEAN = "clean" | ||
TRAIN = "train" | ||
VALIDATION = "validation" | ||
TEST = "test" | ||
MODEL = "model" | ||
MODEL_CONFIGS = "model_configs" | ||
MODEL_INPUT = "model_input" | ||
MODEL_OUTPUT = "model_output" | ||
METRICS = "metrics" | ||
PREDICTIONS = "predictions" | ||
EXPLAINER = "explainer" | ||
TUNING_PARAMS = "tuning_params" | ||
|
||
def __init__(self, initial_data: Optional[dict] = None): | ||
""" | ||
Initialize the DataContainer with an empty dictionary or provided data. | ||
Parameters | ||
---------- | ||
initial_data : dict, optional | ||
Initial data to populate the container. | ||
""" | ||
self.data = initial_data if initial_data is not None else {} | ||
self.logger = logging.getLogger(self.__class__.__name__) | ||
self.logger.debug(f"{self.__class__.__name__} initialized") | ||
|
||
def add(self, key: str, value): | ||
""" | ||
Add a new item to the container. | ||
Parameters | ||
---------- | ||
key : str | ||
The key under which the value is stored. | ||
value | ||
The data to be stored. | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.data[key] = value | ||
self.logger.debug(f"Data added under key: {key}") | ||
|
||
def get(self, key: str, default=None): | ||
""" | ||
Retrieve an item from the container by its key. | ||
Parameters | ||
---------- | ||
key : str | ||
The key of the item to retrieve. | ||
default | ||
The default value to return if the key is not found. Defaults to None. | ||
Returns | ||
------- | ||
The data stored under the given key or the default value. | ||
""" | ||
return self.data.get(key, default) | ||
|
||
def __getitem__(self, key: str): | ||
""" | ||
Retrieve an item using bracket notation. | ||
Parameters | ||
---------- | ||
key : str | ||
The key of the item to retrieve. | ||
Returns | ||
------- | ||
The data stored under the given key. | ||
""" | ||
return self.get(key) | ||
|
||
def __setitem__(self, key: str, value): | ||
""" | ||
Add or update an item using bracket notation. | ||
Parameters | ||
---------- | ||
key : str | ||
The key under which the value is stored. | ||
value | ||
The data to be stored. | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.add(key, value) | ||
|
||
def contains(self, key: str) -> bool: | ||
""" | ||
Check if the container contains an item with the specified key. | ||
Parameters | ||
---------- | ||
key : str | ||
The key to check in the container. | ||
Returns | ||
------- | ||
bool | ||
True if the key exists, False otherwise. | ||
""" | ||
return key in self.data | ||
|
||
def __contains__(self, key: str) -> bool: | ||
""" | ||
Enable usage of the 'in' keyword. | ||
Parameters | ||
---------- | ||
key : str | ||
The key to check in the container. | ||
Returns | ||
------- | ||
bool | ||
True if the key exists, False otherwise. | ||
""" | ||
return self.contains(key) | ||
|
||
@property | ||
def keys(self) -> list[str]: | ||
""" | ||
Return the keys of the container. | ||
Returns | ||
------- | ||
list[str] | ||
The keys of the container. | ||
""" | ||
return list(self.data.keys()) | ||
|
||
def save(self, file_path: str, keys: Optional[Union[str, list[str]]] = None): | ||
""" | ||
Serialize the container data using pickle and save it to a file. | ||
Parameters | ||
---------- | ||
file_path : str | ||
The path of the file where the serialized data should be saved. | ||
keys : Optional[Union[str, List[str]]], optional | ||
The keys of the data to be saved. If None, all data is saved. | ||
Returns | ||
------- | ||
None | ||
""" | ||
if isinstance(keys, str): | ||
keys = [keys] | ||
|
||
data_to_save = {k: self.data[k] for k in keys} if keys else self.data | ||
|
||
serialized_data = pickle.dumps(data_to_save) | ||
data_size_bytes = sys.getsizeof(serialized_data) | ||
data_size_mb = data_size_bytes / 1048576 # Convert bytes to megabytes | ||
|
||
with open(file_path, "wb") as file: | ||
file.write(serialized_data) | ||
self.logger.info( | ||
f"{self.__class__.__name__} serialized and saved to {file_path}. Size:" | ||
f" {data_size_mb:.2f} MB" | ||
) | ||
|
||
@classmethod | ||
def load(cls, file_path: str, keys: Optional[Union[str, list[str]]] = None) -> DataContainer: | ||
""" | ||
Load data from a file and return a new instance of DataContainer. | ||
Parameters | ||
---------- | ||
file_path : str | ||
The path of the file from which the serialized data should be read. | ||
keys : Optional[Union[str, List[str]]], optional | ||
The keys of the data to be loaded. If None, all data is loaded. | ||
Returns | ||
------- | ||
DataContainer | ||
A new instance of DataContainer populated with the deserialized data. | ||
""" | ||
with open(file_path, "rb") as file: | ||
data = pickle.loads(file.read()) | ||
|
||
if isinstance(keys, str): | ||
keys = [keys] | ||
|
||
if keys: | ||
data = {k: v for k, v in data.items() if k in keys} | ||
|
||
new_container = cls(initial_data=data) | ||
|
||
if keys: | ||
loaded_keys = set(new_container.keys) | ||
not_loaded_keys = set(keys) - loaded_keys if keys else set() | ||
if not_loaded_keys: | ||
new_container.logger.warning(f"Keys without values: {not_loaded_keys}") | ||
|
||
new_container.logger.info(f"{cls.__name__} loaded from {file_path}") | ||
return new_container | ||
|
||
def __eq__(self, other) -> bool: | ||
""" | ||
Compare this DataContainer with another for equality. | ||
Parameters | ||
---------- | ||
other : DataContainer | ||
Another DataContainer instance to compare with. | ||
Returns | ||
------- | ||
bool | ||
True if containers are equal, False otherwise. | ||
""" | ||
if isinstance(other, DataContainer): | ||
return self.data == other.data | ||
return False | ||
|
||
def __ne__(self, other) -> bool: | ||
""" | ||
Compare this DataContainer with another for inequality. | ||
Parameters | ||
---------- | ||
other : DataContainer | ||
Another DataContainer instance to compare with. | ||
Returns | ||
------- | ||
bool | ||
True if containers are not equal, False otherwise. | ||
""" | ||
return not self.__eq__(other) | ||
|
||
def __str__(self): | ||
""" | ||
Generate a user-friendly JSON string representation of the DataContainer. | ||
Returns | ||
------- | ||
str | ||
A JSON string describing the keys and types of contents of the DataContainer. | ||
""" | ||
data_summary = {key: type(value).__name__ for key, value in self.data.items()} | ||
return json.dumps(data_summary, indent=4) | ||
|
||
def __repr__(self): | ||
""" | ||
Generate an official string representation of the DataContainer. | ||
Returns | ||
------- | ||
str | ||
A formal string representation of the DataContainer's state. | ||
""" | ||
return f"<DataContainer({self.data})>" |
pipeline_lib/core/pipeline.py
@@ -0,0 +1,51 @@
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Optional

from pipeline_lib.core.data_container import DataContainer
from pipeline_lib.core.steps import PipelineStep


class Pipeline(ABC):
    """Base class for pipelines."""

    def __init__(self, initial_data: Optional[DataContainer] = None):
        self.init_logger()  # ensure self.logger exists before run() uses it
        self.steps = self.define_steps()
        if not all(isinstance(step, PipelineStep) for step in self.steps):
            raise TypeError("All steps must be instances of PipelineStep")
        self.initial_data = initial_data

    def run(self, data: Optional[DataContainer] = None) -> DataContainer:
        """Run the pipeline on the given data."""
        if data is None:
            if self.initial_data is None:
                raise ValueError("No data given and no initial data set")
            self.logger.debug("No data given, using initial data")
            data = self.initial_data

        for i, step in enumerate(self.steps):
            self.logger.info(f"Running {step.__class__.__name__} - {i + 1} / {len(self.steps)}")
            data = step.execute(data)
        return data

    @abstractmethod
    def define_steps(self) -> list[PipelineStep]:
        """
        Subclasses should implement this method to define their specific steps.
        """

    def init_logger(self) -> None:
        """Initialize the logger."""
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug(f"{self.__class__.__name__} initialized")

    def __str__(self) -> str:
        step_names = [f"{i + 1}. {step.__class__.__name__}" for i, step in enumerate(self.steps)]
        return f"{self.__class__.__name__} with steps:\n" + "\n".join(step_names)

    def __repr__(self) -> str:
        """Return an unambiguous string representation of the pipeline."""
        step_names = [f"{step.__class__.__name__}()" for step in self.steps]
        return f"{self.__class__.__name__}({', '.join(step_names)})"