
Commit e2e99ec

add data flow key for improving step data connection
1 parent 0ec5f16 commit e2e99ec

7 files changed, +31 −5 lines changed

pipeline_lib/core/data_container.py

Lines changed: 24 additions & 0 deletions
@@ -592,6 +592,30 @@ def features(self, value: Any):
         """
         self["features"] = value
 
+    @property
+    def flow(self) -> Any:
+        """
+        Get the flow from the DataContainer.
+
+        Returns
+        -------
+        Any
+            The flow stored in the DataContainer.
+        """
+        return self["flow"]
+
+    @flow.setter
+    def flow(self, value: Any):
+        """
+        Set the flow in the DataContainer.
+
+        Parameters
+        ----------
+        value
+            The flow to be stored in the DataContainer.
+        """
+        self["flow"] = value
+
     def __eq__(self, other) -> bool:
         """
         Compare this DataContainer with another for equality.
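
The new flow property gives every step a single handle on the "current" DataFrame as it moves through the pipeline. Below is a minimal sketch of the accessor pattern using a dict-backed stand-in for DataContainer; the MiniContainer class and the toy step are illustrative only, not part of the library.

import pandas as pd


class MiniContainer(dict):
    """Dict-backed stand-in mirroring the flow property added in this commit."""

    @property
    def flow(self):
        # Same behaviour as the real getter: raises KeyError if flow was never set.
        return self["flow"]

    @flow.setter
    def flow(self, value):
        self["flow"] = value


def drop_duplicates_step(data: MiniContainer) -> MiniContainer:
    # Toy step: read the working frame from flow, transform it, write it back.
    df: pd.DataFrame = data.flow
    data.flow = df.drop_duplicates()
    return data


data = MiniContainer()
data.flow = pd.DataFrame({"a": [1, 1, 2]})
data = drop_duplicates_step(data)
print(len(data.flow))  # 2 -- the next step keeps reading from data.flow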

pipeline_lib/core/steps/calculate_features.py

Lines changed: 2 additions & 2 deletions
@@ -70,7 +70,7 @@ def execute(self, data: DataContainer) -> DataContainer:
         """Execute the step."""
         self.logger.info("Calculating features")
 
-        df = data.clean
+        df = data.flow
         created_features = []
 
         if self.datetime_columns:
@@ -97,6 +97,6 @@ def execute(self, data: DataContainer) -> DataContainer:
 
         self.logger.info(f"Created new features: {created_features}")
 
-        data.features = df
+        data.flow = df
 
         return data

pipeline_lib/core/steps/clean.py

Lines changed: 1 addition & 0 deletions
@@ -112,5 +112,6 @@ def execute(self, data: DataContainer) -> DataContainer:
                 self.logger.warning(f"Column '{column}' not found in the DataFrame")
 
         data.clean = df
+        data.flow = df
 
         return data
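
Producer steps now follow the same pattern as clean.py above: keep the step-specific snapshot (data.raw, data.clean, and so on) and also point data.flow at the latest frame, so downstream steps never have to guess which key is freshest. A hedged sketch of that pattern inside a step's execute method; the class name, the omitted base class, and the import path are assumptions for illustration:

import pandas as pd

from pipeline_lib.core.data_container import DataContainer  # assumed import path


class CleanLikeStep:
    """Illustrative skeleton of the snapshot-plus-flow pattern (not the real CleanStep)."""

    def execute(self, data: DataContainer) -> DataContainer:
        df: pd.DataFrame = data.flow   # frame handed over by the previous step
        df = df.dropna()               # stand-in transformation
        data.clean = df                # named snapshot, unchanged behaviour
        data.flow = df                 # new: publish the latest frame for the next step
        return data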

pipeline_lib/core/steps/explainer_dashboard.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def execute(self, data: DataContainer) -> DataContainer:
         if target is None:
             raise ValueError("Target column not found in any parameter.")
 
-        df = data.features if data.features is not None else data.clean
+        df = data.flow
 
         if len(df) > self.max_samples:
             # Randomly sample a subset of data points if the dataset is larger than max_samples

pipeline_lib/core/steps/generate.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ def execute(self, data: DataContainer) -> DataContainer:
             raise ValueError(f"Unsupported file type: {file_type}")
 
         data.raw = df
+        data.flow = df
 
         self.logger.info(f"Generated DataFrame with shape: {df.shape}")
 

pipeline_lib/core/steps/tabular_split.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ def execute(self, data: DataContainer) -> DataContainer:
         """Execute the random train-validation split."""
         self.logger.info("Splitting tabular data...")
 
-        df = data.features if data.features is not None else data.clean
+        df = data.flow
 
         train_df, validation_df = train_test_split(
             df, train_size=self.train_percentage, random_state=42

pipeline_lib/implementation/tabular/xgboost/predict.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def __init__(
     def execute(self, data: DataContainer) -> DataContainer:
         self.logger.debug("Obtaining predictions for XGBoost model.")
 
-        model_input = data.features if data.features is not None else data.clean
+        model_input = data.flow
 
         if self.drop_columns:
             self.logger.info(f"Dropping columns: {self.drop_columns}")
