Skip to content

Commit 58bca73

Browse files
committed
save runs & metrics to folder
1 parent bc9a4df commit 58bca73

File tree

1 file changed

+40
-4
lines changed

1 file changed

+40
-4
lines changed

pipeline_lib/core/pipeline.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import json
44
import logging
5-
from typing import Optional
5+
import os
6+
from datetime import datetime
7+
from typing import Any, Optional
68

79
from pipeline_lib.core.data_container import DataContainer
810
from pipeline_lib.core.model_registry import ModelRegistry
@@ -24,12 +26,13 @@ def __init__(self, initial_data: Optional[DataContainer] = None):
2426
self.save_path = None
2527
self.load_path = None
2628
self.model_path = None
29+
self.config = None
2730

2831
def add_steps(self, steps: list[PipelineStep]):
    """Append the given steps, in order, to the end of this pipeline's step list."""
    # Mutates the existing ``self.steps`` container in place.
    registered_steps = self.steps
    registered_steps.extend(steps)
3134

32-
def run(self, is_train: bool) -> DataContainer:
35+
def run(self, is_train: bool, save: bool = True) -> DataContainer:
3336
"""Run the pipeline on the given data."""
3437

3538
data = DataContainer.from_pickle(self.load_path) if self.load_path else DataContainer()
@@ -48,8 +51,8 @@ def run(self, is_train: bool) -> DataContainer:
4851
)
4952
data = step.execute(data)
5053

51-
if self.save_path:
52-
data.save(self.save_path)
54+
if save:
55+
self.save_run(data)
5356

5457
return data
5558

@@ -79,6 +82,7 @@ def from_json(cls, path: str) -> Pipeline:
7982

8083
pipeline.load_path = config.get("load_path")
8184
pipeline.save_path = config.get("save_path")
85+
pipeline.config = config
8286

8387
steps = []
8488

@@ -116,6 +120,31 @@ def from_json(cls, path: str) -> Pipeline:
116120
pipeline.add_steps(steps)
117121
return pipeline
118122

123+
def save_run(
    self,
    data: DataContainer,
    parent_folder: str = "runs",
    logs: Optional[logging.LogRecord] = None,
) -> None:
    """Save the configuration and metrics of a pipeline run to a new folder.

    A folder named ``<ClassName>_<YYYYmmdd_HHMMSS>`` is created under
    ``parent_folder``. ``self.config`` is written to ``pipeline_config.json``
    and, when ``data.metrics`` is truthy, it is written to ``metrics.json``.

    Args:
        data: Container produced by the run; only ``data.metrics`` is read.
        parent_folder: Directory under which the run folder is created.
        logs: Currently unused; kept so existing callers keep working.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    folder_name = f"{self.__class__.__name__}_{timestamp}"
    base_folder = os.path.join(parent_folder, folder_name)

    # Create the run folder. The timestamp only has one-second resolution, so
    # two runs started within the same second would otherwise share a folder
    # and mix their files (e.g. a stale metrics.json from the earlier run).
    # On a name clash, append a numeric suffix instead of reusing the folder.
    run_folder = base_folder
    suffix = 1
    while True:
        try:
            os.makedirs(run_folder)
            break
        except FileExistsError:
            run_folder = f"{base_folder}_{suffix}"
            suffix += 1

    # Save the JSON configuration. The encoding is pinned so the output does
    # not depend on the platform's locale.
    config_path = os.path.join(run_folder, "pipeline_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(self.config, f, indent=4, cls=CustomJSONEncoder)

    # Save the training metrics (skipped when there are none).
    if data.metrics:
        metrics_path = os.path.join(run_folder, "metrics.json")
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(data.metrics, f, indent=4)

    self.logger.info("Pipeline run saved to %s", run_folder)
147+
119148
def __str__(self) -> str:
    """Return a readable summary: a header line followed by the numbered step class names."""
    header = f"{self.__class__.__name__} with steps:\n"
    # Steps are numbered starting at 1, one per line.
    numbered = (
        f"{position}. {step.__class__.__name__}"
        for position, step in enumerate(self.steps, start=1)
    )
    return header + "\n".join(numbered)
@@ -124,3 +153,10 @@ def __repr__(self) -> str:
124153
"""Return an unambiguous string representation of the pipeline."""
125154
step_names = [f"{step.__class__.__name__}()" for step in self.steps]
126155
return f"{self.__class__.__name__}({', '.join(step_names)})"
156+
157+
158+
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes class objects as their ``__name__``."""

    def default(self, obj: Any) -> Any:
        """Return the class name for class objects; defer to the base encoder otherwise."""
        if not isinstance(obj, type):
            # The base implementation raises TypeError for unsupported objects.
            return super().default(obj)
        return obj.__name__

0 commit comments

Comments
 (0)