Skip to content

Commit

Permalink
from_json & from_yaml data containers
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Mar 4, 2024
1 parent ad8e35c commit 48ee9e4
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
71 changes: 69 additions & 2 deletions pipeline_lib/core/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import sys
from typing import Optional, Union

import yaml


class DataContainer:
"""
Expand Down Expand Up @@ -193,9 +195,11 @@ def save(self, file_path: str, keys: Optional[Union[str, list[str]]] = None):
)

@classmethod
def load(cls, file_path: str, keys: Optional[Union[str, list[str]]] = None) -> DataContainer:
def from_pickle(
cls, file_path: str, keys: Optional[Union[str, list[str]]] = None
) -> DataContainer:
"""
Load data from a file and return a new instance of DataContainer.
Load data from a pickle file and return a new instance of DataContainer.
Parameters
----------
Expand All @@ -209,6 +213,10 @@ def load(cls, file_path: str, keys: Optional[Union[str, list[str]]] = None) -> D
DataContainer
A new instance of DataContainer populated with the deserialized data.
"""
# Check file is a pickle file
if not file_path.endswith(".pkl"):
raise ValueError(f"File {file_path} is not a pickle file")

with open(file_path, "rb") as file:
data = pickle.loads(file.read())

Expand All @@ -229,6 +237,65 @@ def load(cls, file_path: str, keys: Optional[Union[str, list[str]]] = None) -> D
new_container.logger.info(f"{cls.__name__} loaded from {file_path}")
return new_container

@classmethod
def from_json(cls, file_path: str) -> DataContainer:
    """
    Create a new DataContainer instance from a JSON file.

    Parameters
    ----------
    file_path : str
        The path to the JSON file containing the configurations.

    Returns
    -------
    DataContainer
        A new instance of DataContainer populated with the data from the JSON file.

    Raises
    ------
    ValueError
        If the provided path does not end with ``.json``, or if the file
        content is not valid JSON (``json.JSONDecodeError`` is a subclass of
        ``ValueError``).
    """
    # Check file is a JSON file
    if not file_path.endswith(".json"):
        raise ValueError(f"File {file_path} is not a JSON file")

    # JSON is UTF-8 by specification; an explicit encoding keeps the result
    # independent of the platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The loaded data is used as the initial data for the DataContainer instance
    return cls(initial_data=data)

@classmethod
def from_yaml(cls, file_path: str) -> DataContainer:
    """
    Create a new DataContainer instance from a YAML file.

    Parameters
    ----------
    file_path : str
        The path to the YAML file containing the configurations.

    Returns
    -------
    DataContainer
        A new instance of DataContainer populated with the data from the YAML file.

    Raises
    ------
    ValueError
        If the provided path does not end with ``.yaml`` or ``.yml``, or if
        the file content cannot be parsed as YAML. In the latter case the
        original ``yaml.YAMLError`` is attached as ``__cause__``.
    """
    # Check if the file has a .yaml or .yml extension
    if not file_path.endswith((".yaml", ".yml")):
        raise ValueError(f"The file {file_path} is not a YAML file.")

    try:
        # Explicit encoding keeps parsing independent of the platform locale.
        with open(file_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except yaml.YAMLError as e:
        # Handle cases where the file content is not valid YAML; chain the
        # parser error so the original traceback stays available to callers.
        raise ValueError(f"Error parsing YAML from {file_path}: {e}") from e

    # The loaded data is used as the initial data for the DataContainer instance.
    # NOTE(review): safe_load returns None for an empty document — confirm the
    # constructor accepts initial_data=None.
    return cls(initial_data=data)

def __eq__(self, other) -> bool:
"""
Compare this DataContainer with another for equality.
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ explainerdashboard = "^0.4.5"
optuna = "^3.5.0"
xgboost = "^2.0.3"
pyarrow = "^15.0.0"
PyYAML = "^6.0.1"

[tool.poetry.group.dev.dependencies]
flake8 = "^7.0.0"
Expand Down

0 comments on commit 48ee9e4

Please sign in to comment.