Skip to content

Commit 83d6f23

Browse files
committed
refactor: simplify data configuration and preprocessing steps
1 parent b3b3830 commit 83d6f23

File tree

3 files changed

+36
-43
lines changed

3 files changed

+36
-43
lines changed

src/mlops_with_databricks/data_preprocessing/dataclasses.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,27 +38,14 @@ class AdClickDataConfig:
3838
class ProcessedAdClickDataConfig:
3939
"""Dataclass for the Processed Ad Click Data configuration."""
4040

41-
target: str = "cat__click_0"
41+
target: str = "click"
4242
num_features: tuple[str] = ("num__age",)
4343
cat_features: tuple[str] = (
44-
"cat__gender_Female",
45-
"cat__gender_Male",
46-
"cat__gender_Non-Binary",
47-
"cat__device_type_Desktop",
48-
"cat__device_type_Mobile",
49-
"cat__device_type_Tablet",
50-
"cat__ad_position_Bottom",
51-
"cat__ad_position_Side",
52-
"cat__ad_position_Top",
53-
"cat__browsing_history_Education",
54-
"cat__browsing_history_Entertainment",
55-
"cat__browsing_history_News",
56-
"cat__browsing_history_Shopping",
57-
"cat__browsing_history_Social_Media",
58-
"cat__time_of_day_Afternoon",
59-
"cat__time_of_day_Evening",
60-
"cat__time_of_day_Morning",
61-
"cat__time_of_day_Night",
44+
"cat__gender",
45+
"cat__device_type",
46+
"cat__ad_position",
47+
"cat__browsing_history",
48+
"cat__time_of_day",
6249
)
6350

6451

@@ -72,9 +59,9 @@ class DatabricksConfig:
7259

7360

7461
class LightGBMConfig(TypedDict):
75-
learning_rate: str = 0.001
76-
n_estimators: str = 200
77-
max_depth: str = 10
62+
learning_rate: float
63+
n_estimators: int
64+
max_depth: int
7865

7966

8067
light_gbm_config = LightGBMConfig(learning_rate=0.001, n_estimators=200, max_depth=10)

src/mlops_with_databricks/data_preprocessing/preprocess.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from sklearn.impute import SimpleImputer
1111
from sklearn.model_selection import train_test_split
1212
from sklearn.pipeline import Pipeline
13-
from sklearn.preprocessing import OneHotEncoder
1413

1514
from mlops_with_databricks.data_preprocessing.dataclasses import AdClickDataColumns, AdClickDataConfig, DatabricksConfig
1615

@@ -32,18 +31,11 @@ def load_data(self, filepath: str | Path) -> None:
3231
def from_pandas(cls, pandas_df: pd.DataFrame) -> "DataProcessor":
3332
"""Create a DataProcessor object from a pandas DataFrame."""
3433
instance = cls()
35-
instance.X = None
36-
instance.y = None
37-
instance.preprocessor = None
3834
instance.df = pandas_df
3935
return instance
4036

4137
def preprocess_data(self) -> None:
42-
"""Preprocess the data. Fill missing values, cast types, and split features and target.
43-
44-
Returns:
45-
tuple[pd.DataFrame, pd.Series]: Preprocessed features and target.
46-
"""
38+
"""Preprocess the data. Fill missing values, cast types, and split features and target."""
4739
self.df = self.df.drop(columns=[AdClickDataColumns.id, AdClickDataColumns.full_name])
4840
self.df = self.fill_missing_values(self.df)
4941
self.df[AdClickDataColumns.browsing_history] = self.df[AdClickDataColumns.browsing_history].str.replace(
@@ -56,17 +48,19 @@ def preprocess_data(self) -> None:
5648
categorical_transformer = Pipeline(
5749
steps=[
5850
("imputer", SimpleImputer(strategy="most_frequent")),
59-
("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False)),
51+
# ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False)),
6052
]
6153
)
6254

6355
self.preprocessor = ColumnTransformer(
6456
transformers=[
6557
("num", numeric_transformer, list(AdClickDataConfig.num_features)),
66-
("cat", categorical_transformer, list(AdClickDataConfig.cat_features) + [AdClickDataConfig.target]),
58+
("cat", categorical_transformer, list(AdClickDataConfig.cat_features)),
6759
]
6860
).set_output(transform="pandas")
69-
self.df = self.preprocessor.fit_transform(self.df)
61+
preprocessed_features = self.preprocessor.fit_transform(self.df)
62+
preprocessed_features["click"] = self.df[AdClickDataColumns.click].astype("int64")
63+
self.df = preprocessed_features
7064

7165
@staticmethod
7266
def fill_missing_values(df: pd.DataFrame) -> pd.DataFrame:

src/mlops_with_databricks/training/train.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Databricks notebook source
22

3+
import subprocess
4+
35
import mlflow
46
from lightgbm import LGBMClassifier
57
from mlflow.models import infer_signature
68
from pyspark.sql import SparkSession
79
from sklearn.compose import ColumnTransformer
8-
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
10+
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
911
from sklearn.model_selection import GridSearchCV
1012
from sklearn.pipeline import Pipeline
1113
from sklearn.preprocessing import OneHotEncoder
@@ -16,6 +18,16 @@
1618
light_gbm_config,
1719
)
1820

21+
22+
def get_git_info():
    """Return the current Git commit SHA and branch name as a tag dictionary.

    Runs ``git rev-parse`` in the current working directory. The lookup is
    best-effort: when Git is not installed, or the working directory is not
    inside a repository with at least one commit, both values fall back to
    ``"unknown"`` instead of raising.

    Returns:
        dict[str, str]: A mapping with the keys ``"git_sha"`` and ``"branch"``.
    """

    def _rev_parse(*args):
        # Silence stderr so git's "fatal: not a git repository" message does not
        # leak into the caller's output when the lookup fails.
        raw = subprocess.check_output(["git", "rev-parse", *args], stderr=subprocess.DEVNULL)
        # Branch names may contain non-ASCII characters; decode leniently so an
        # unusual branch name does not discard an otherwise valid SHA.
        return raw.decode("utf-8", errors="replace").strip()

    try:
        return {
            "git_sha": _rev_parse("HEAD"),
            "branch": _rev_parse("--abbrev-ref", "HEAD"),
        }
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git exited non-zero (e.g. not a repository).
        # OSError: the git executable is missing or cannot be launched.
        return {"git_sha": "unknown", "branch": "unknown"}
29+
30+
1931
mlflow.set_tracking_uri("databricks://dbc-643c4c2b-d6c9")
2032
mlflow.set_registry_uri("databricks-uc://dbc-643c4c2b-d6c9") # It must be -uc for registering models to Unity Catalog
2133

@@ -50,7 +62,7 @@
5062
)
5163

5264
# Create the pipeline with preprocessing and the LightGBM classifier
53-
pipeline = Pipeline(steps=[("classifier", LGBMClassifier(**parameters))])
65+
pipeline = Pipeline(steps=[("onehot", preprocessor), ("classifier", LGBMClassifier(**parameters))])
5466

5567
# Define parameter grid for hyperparameter tuning
5668
param_grid = {
@@ -60,15 +72,14 @@
6072
}
6173

6274
# Perform hyperparameter tuning with GridSearchCV
63-
grid_search = GridSearchCV(pipeline, param_grid, cv=4, scoring="roc_auc", n_jobs=-1)
75+
grid_search = GridSearchCV(pipeline, param_grid, cv=4, scoring="f1", n_jobs=-1)
6476

6577
# COMMAND ----------
6678
mlflow.set_experiment(experiment_name="/Shared/ad-click")
67-
git_sha = "ffa63b430205ff7"
6879

6980
# Start an MLflow run to track the training process
7081
with mlflow.start_run(
71-
tags={"git_sha": f"{git_sha}", "branch": "week2"},
82+
tags=get_git_info(),
7283
) as run:
7384
run_id = run.info.run_id
7485

@@ -82,24 +93,25 @@
8293
precision = precision_score(y_test, y_pred)
8394
recall = recall_score(y_test, y_pred)
8495
roc_auc = roc_auc_score(y_test, y_pred)
96+
accuracy = accuracy_score(y_test, y_pred)
8597

8698
# Log parameters, metrics, and the model to MLflow
8799
mlflow.log_param("model_type", "LightGBM with preprocessing")
88100
mlflow.log_params(best_params)
89-
mlflow.log_metrics({"f1": f1, "precision": precision, "recall": recall, "roc_auc": roc_auc})
90-
signature = infer_signature(model_input=X_train, model_output=y_pred)
101+
mlflow.log_metrics({"f1": f1, "accuracy": accuracy, "precision": precision, "recall": recall, "roc_auc": roc_auc})
102+
signature = infer_signature(model_input=X_test, model_output=y_pred)
91103

92104
dataset = mlflow.data.from_spark(train_set_spark, table_name=f"{catalog_name}.{schema_name}.train_set", version="0")
93105
mlflow.log_input(dataset, context="training")
94106

95-
mlflow.sklearn.log_model(sk_model=pipeline, artifact_path="lightgbm-pipeline-model", signature=signature)
107+
mlflow.sklearn.log_model(sk_model=best_pipeline, artifact_path="lightgbm-pipeline-model", signature=signature)
96108

97109

98110
# COMMAND ----------
99111
model_version = mlflow.register_model(
100112
model_uri=f"runs:/{run_id}/lightgbm-pipeline-model",
101113
name=f"{catalog_name}.{schema_name}.ad_click_model_basic",
102-
tags={"git_sha": f"{git_sha}"},
114+
tags=get_git_info(),
103115
)
104116

105117
# COMMAND ----------

0 commit comments

Comments
 (0)