Skip to content

Commit

Permalink
Add "flow" data key to improve data connection between pipeline steps
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Mar 19, 2024
1 parent 0ec5f16 commit e2e99ec
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 5 deletions.
24 changes: 24 additions & 0 deletions pipeline_lib/core/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,30 @@ def features(self, value: Any):
"""
self["features"] = value

@property
def flow(self) -> Any:
    """
    Retrieve the current flow value from the DataContainer.

    Returns
    -------
    Any
        The object stored under the "flow" key.
    """
    stored_value = self["flow"]
    return stored_value

@flow.setter
def flow(self, value: Any):
    """
    Store a new flow value in the DataContainer.

    Parameters
    ----------
    value
        The object to store under the "flow" key.
    """
    self["flow"] = value

def __eq__(self, other) -> bool:
"""
Compare this DataContainer with another for equality.
Expand Down
4 changes: 2 additions & 2 deletions pipeline_lib/core/steps/calculate_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def execute(self, data: DataContainer) -> DataContainer:
"""Execute the step."""
self.logger.info("Calculating features")

df = data.clean
df = data.flow
created_features = []

if self.datetime_columns:
Expand All @@ -97,6 +97,6 @@ def execute(self, data: DataContainer) -> DataContainer:

self.logger.info(f"Created new features: {created_features}")

data.features = df
data.flow = df

return data
1 change: 1 addition & 0 deletions pipeline_lib/core/steps/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,6 @@ def execute(self, data: DataContainer) -> DataContainer:
self.logger.warning(f"Column '{column}' not found in the DataFrame")

data.clean = df
data.flow = df

return data
2 changes: 1 addition & 1 deletion pipeline_lib/core/steps/explainer_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def execute(self, data: DataContainer) -> DataContainer:
if target is None:
raise ValueError("Target column not found in any parameter.")

df = data.features if data.features is not None else data.clean
df = data.flow

if len(df) > self.max_samples:
# Randomly sample a subset of data points if the dataset is larger than max_samples
Expand Down
1 change: 1 addition & 0 deletions pipeline_lib/core/steps/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def execute(self, data: DataContainer) -> DataContainer:
raise ValueError(f"Unsupported file type: {file_type}")

data.raw = df
data.flow = df

self.logger.info(f"Generated DataFrame with shape: {df.shape}")

Expand Down
2 changes: 1 addition & 1 deletion pipeline_lib/core/steps/tabular_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def execute(self, data: DataContainer) -> DataContainer:
"""Execute the random train-validation split."""
self.logger.info("Splitting tabular data...")

df = data.features if data.features is not None else data.clean
df = data.flow

train_df, validation_df = train_test_split(
df, train_size=self.train_percentage, random_state=42
Expand Down
2 changes: 1 addition & 1 deletion pipeline_lib/implementation/tabular/xgboost/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
def execute(self, data: DataContainer) -> DataContainer:
self.logger.debug("Obtaining predictions for XGBoost model.")

model_input = data.features if data.features is not None else data.clean
model_input = data.flow

if self.drop_columns:
self.logger.info(f"Dropping columns: {self.drop_columns}")
Expand Down

0 comments on commit e2e99ec

Please sign in to comment.