Skip to content

Commit

Permalink
change target param to Generate Step
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Apr 12, 2024
1 parent ed9e121 commit 0cb6b3a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
6 changes: 2 additions & 4 deletions pipeline_lib/core/steps/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@ class EncodeStep(PipelineStep):

def __init__(
self,
target: Optional[str] = None,
cardinality_threshold: float = 0.3,
feature_encoders: Optional[dict] = None,
) -> None:
"""Initialize EncodeStep."""
self.init_logger()
self.target = target
self.cardinality_threshold = cardinality_threshold
self.feature_encoders = feature_encoders or {}

Expand All @@ -42,10 +40,10 @@ def execute(self, data: DataContainer) -> DataContainer:
self.logger.info("Encoding data")
df = data.flow

if not data.target and not self.target:
if not data.target:
raise ValueError("Target column not found in any parameter before encoding.")

target_column_name = self.target or data.target
target_column_name = data.target

categorical_features, numeric_features = self._get_feature_types(df, target_column_name)
low_cardinality_features, high_cardinality_features = self._split_categorical_features(
Expand Down
15 changes: 6 additions & 9 deletions pipeline_lib/core/steps/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class FitModelStep(PipelineStep):
def __init__(
self,
model_class: Type[Model],
target: str,
model_params: Optional[dict] = None,
drop_columns: Optional[list[str]] = None,
optuna_params: Optional[dict] = None,
Expand All @@ -99,7 +98,6 @@ def __init__(
super().__init__()
self.init_logger()
self.model_class = model_class
self.target = target
self.model_params = model_params or {}
self.drop_columns = drop_columns or []
self.optuna_params = optuna_params
Expand All @@ -109,7 +107,7 @@ def execute(self, data: DataContainer) -> DataContainer:
self.logger.info(f"Fitting the {self.model_class.__name__} model")

df_train, df_valid = self._prepare_data(data)
X_train, y_train, X_valid, y_valid = self._extract_target(df_train, df_valid)
X_train, y_train, X_valid, y_valid = self._extract_target(df_train, df_valid, data.target)

model_params = self.model_params

Expand All @@ -126,7 +124,6 @@ def execute(self, data: DataContainer) -> DataContainer:
model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], verbose=True)

data.model = model
data.target = self.target
data._drop_columns = self.drop_columns

if self.save_path:
Expand All @@ -146,13 +143,13 @@ def _prepare_data(self, data: DataContainer) -> tuple:
return df_train, df_valid

def _extract_target(
self, df_train: pd.DataFrame, df_valid: pd.DataFrame
self, df_train: pd.DataFrame, df_valid: pd.DataFrame, target: str
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
"""Extract target column from the dataframes, to be used in model fitting."""
X_train = df_train.drop(columns=[self.target])
y_train = df_train[self.target]
X_train = df_train.drop(columns=[target])
y_train = df_train[target]

X_valid = df_valid.drop(columns=[self.target])
y_valid = df_valid[self.target]
X_valid = df_valid.drop(columns=[target])
y_valid = df_valid[target]

return X_train, y_train, X_valid, y_valid
8 changes: 7 additions & 1 deletion pipeline_lib/core/steps/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ class GenerateStep(PipelineStep):
used_for_training = True

def __init__(
self, train_path: Optional[str] = None, predict_path: Optional[str] = None, **kwargs
self,
target: str,
train_path: Optional[str] = None,
predict_path: Optional[str] = None,
**kwargs,
):
self.init_logger()
self.target = target
self.train_path = train_path
self.predict_path = predict_path
self.kwargs = kwargs
Expand Down Expand Up @@ -58,6 +63,7 @@ def execute(self, data: DataContainer) -> DataContainer:

data.raw = df
data.flow = df
data.target = self.target

self.logger.info(f"Generated DataFrame with shape: {df.shape}")

Expand Down

0 comments on commit 0cb6b3a

Please sign in to comment.