Skip to content

Commit

Permalink
save runs & metrics to folder
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Mar 25, 2024
1 parent bc9a4df commit 58bca73
Showing 1 changed file with 40 additions and 4 deletions.
44 changes: 40 additions & 4 deletions pipeline_lib/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import json
import logging
from typing import Optional
import os
from datetime import datetime
from typing import Any, Optional

from pipeline_lib.core.data_container import DataContainer
from pipeline_lib.core.model_registry import ModelRegistry
Expand All @@ -24,12 +26,13 @@ def __init__(self, initial_data: Optional[DataContainer] = None):
self.save_path = None
self.load_path = None
self.model_path = None
self.config = None

def add_steps(self, steps: list[PipelineStep]):
    """Append the given steps, in order, to the end of the pipeline's step list."""
    registered = self.steps
    registered.extend(steps)

def run(self, is_train: bool) -> DataContainer:
def run(self, is_train: bool, save: bool = True) -> DataContainer:
"""Run the pipeline on the given data."""

data = DataContainer.from_pickle(self.load_path) if self.load_path else DataContainer()
Expand All @@ -48,8 +51,8 @@ def run(self, is_train: bool) -> DataContainer:
)
data = step.execute(data)

if self.save_path:
data.save(self.save_path)
if save:
self.save_run(data)

return data

Expand Down Expand Up @@ -79,6 +82,7 @@ def from_json(cls, path: str) -> Pipeline:

pipeline.load_path = config.get("load_path")
pipeline.save_path = config.get("save_path")
pipeline.config = config

steps = []

Expand Down Expand Up @@ -116,6 +120,31 @@ def from_json(cls, path: str) -> Pipeline:
pipeline.add_steps(steps)
return pipeline

def save_run(
    self,
    data: DataContainer,
    parent_folder: str = "runs",
    logs: Optional[logging.LogRecord] = None,
) -> None:
    """Save the pipeline run to a new timestamped folder.

    A folder named ``<ClassName>_<YYYYmmdd_HHMMSS>`` is created inside
    ``parent_folder``. The pipeline configuration (``self.config``) is
    written to ``pipeline_config.json`` and, when ``data.metrics`` is
    truthy, the metrics are written to ``metrics.json``.

    Args:
        data: Container whose ``metrics`` attribute is persisted.
        parent_folder: Directory under which the run folder is created.
        logs: Currently unused; kept so existing callers keep working.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    folder_name = f"{self.__class__.__name__}_{timestamp}"
    base_folder = os.path.join(parent_folder, folder_name)

    # The timestamp only has one-second resolution, so two runs saved within
    # the same second would otherwise share a folder and overwrite each
    # other's files. Add a numeric suffix only when the folder already exists.
    run_folder = base_folder
    attempt = 0
    while True:
        try:
            os.makedirs(run_folder)
            break
        except FileExistsError:
            attempt += 1
            run_folder = f"{base_folder}_{attempt}"

    # Save the JSON configuration
    config_path = os.path.join(run_folder, "pipeline_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(self.config, f, indent=4, cls=CustomJSONEncoder)

    # Save the training metrics
    if data.metrics:
        metrics_path = os.path.join(run_folder, "metrics.json")
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(data.metrics, f, indent=4)

    self.logger.info("Pipeline run saved to %s", run_folder)

def __str__(self) -> str:
    """Return a readable summary: the class name followed by a numbered list of step class names."""
    header = f"{self.__class__.__name__} with steps:\n"
    numbered = (
        f"{position}. {step.__class__.__name__}"
        for position, step in enumerate(self.steps, start=1)
    )
    return header + "\n".join(numbered)
Expand All @@ -124,3 +153,10 @@ def __repr__(self) -> str:
"""Return an unambiguous string representation of the pipeline."""
step_names = [f"{step.__class__.__name__}()" for step in self.steps]
return f"{self.__class__.__name__}({', '.join(step_names)})"


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes class objects as their names."""

    def default(self, obj: Any) -> Any:
        """Return a JSON-serializable stand-in for ``obj``.

        Class objects are encoded as their ``__name__``; every other value is
        handed to the base encoder's ``default``.
        """
        if not isinstance(obj, type):
            return super().default(obj)
        return obj.__name__

0 comments on commit 58bca73

Please sign in to comment.