|
1 | 1 | # Databricks notebook source
|
2 | 2 |
|
| 3 | +import subprocess |
| 4 | + |
3 | 5 | import mlflow
|
4 | 6 | from lightgbm import LGBMClassifier
|
5 | 7 | from mlflow.models import infer_signature
|
6 | 8 | from pyspark.sql import SparkSession
|
7 | 9 | from sklearn.compose import ColumnTransformer
|
8 |
| -from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score |
| 10 | +from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score |
9 | 11 | from sklearn.model_selection import GridSearchCV
|
10 | 12 | from sklearn.pipeline import Pipeline
|
11 | 13 | from sklearn.preprocessing import OneHotEncoder
|
|
16 | 18 | light_gbm_config,
|
17 | 19 | )
|
18 | 20 |
|
| 21 | + |
def _run_git(*args):
    """Return the stripped stdout of ``git <args>``, or ``"unknown"`` on failure."""
    try:
        # stderr is discarded so a missing repository does not spam the run output.
        raw = subprocess.check_output(["git", *args], stderr=subprocess.DEVNULL)
    except (OSError, subprocess.SubprocessError):
        # OSError: the git executable is missing or not runnable.
        # SubprocessError: git exited non-zero (e.g. not inside a work tree).
        return "unknown"
    # Branch names are not guaranteed to be ASCII; never let decoding raise.
    return raw.decode("utf-8", errors="replace").strip() or "unknown"


def get_git_info():
    """Collect the current git commit and branch as a tag dictionary.

    Each value is looked up independently, so a failure in one lookup does
    not discard the other.

    Returns:
        A dict with exactly two string entries, ``"git_sha"`` and
        ``"branch"``. A value is ``"unknown"`` when git could not be run,
        exited with an error, or printed nothing.
    """
    return {
        "git_sha": _run_git("rev-parse", "HEAD"),
        "branch": _run_git("rev-parse", "--abbrev-ref", "HEAD"),
    }
| 29 | + |
| 30 | + |
19 | 31 | mlflow.set_tracking_uri("databricks://dbc-643c4c2b-d6c9")
|
20 | 32 | mlflow.set_registry_uri("databricks-uc://dbc-643c4c2b-d6c9") # It must be -uc for registering models to Unity Catalog
|
21 | 33 |
|
|
50 | 62 | )
|
51 | 63 |
|
52 | 64 | # Create the pipeline with preprocessing and the LightGBM classifier
|
53 |
| -pipeline = Pipeline(steps=[("classifier", LGBMClassifier(**parameters))]) |
| 65 | +pipeline = Pipeline(steps=[("onehot", preprocessor), ("classifier", LGBMClassifier(**parameters))]) |
54 | 66 |
|
55 | 67 | # Define parameter grid for hyperparameter tuning
|
56 | 68 | param_grid = {
|
|
60 | 72 | }
|
61 | 73 |
|
62 | 74 | # Perform hyperparameter tuning with GridSearchCV
|
63 |
| -grid_search = GridSearchCV(pipeline, param_grid, cv=4, scoring="roc_auc", n_jobs=-1) |
| 75 | +grid_search = GridSearchCV(pipeline, param_grid, cv=4, scoring="f1", n_jobs=-1) |
64 | 76 |
|
65 | 77 | # COMMAND ----------
|
66 | 78 | mlflow.set_experiment(experiment_name="/Shared/ad-click")
|
67 |
| -git_sha = "ffa63b430205ff7" |
68 | 79 |
|
69 | 80 | # Start an MLflow run to track the training process
|
70 | 81 | with mlflow.start_run(
|
71 |
| - tags={"git_sha": f"{git_sha}", "branch": "week2"}, |
| 82 | + tags=get_git_info(), |
72 | 83 | ) as run:
|
73 | 84 | run_id = run.info.run_id
|
74 | 85 |
|
|
82 | 93 | precision = precision_score(y_test, y_pred)
|
83 | 94 | recall = recall_score(y_test, y_pred)
|
84 | 95 | roc_auc = roc_auc_score(y_test, y_pred)
|
| 96 | + accuracy = accuracy_score(y_test, y_pred) |
85 | 97 |
|
86 | 98 | # Log parameters, metrics, and the model to MLflow
|
87 | 99 | mlflow.log_param("model_type", "LightGBM with preprocessing")
|
88 | 100 | mlflow.log_params(best_params)
|
89 |
| - mlflow.log_metrics({"f1": f1, "precision": precision, "recall": recall, "roc_auc": roc_auc}) |
90 |
| - signature = infer_signature(model_input=X_train, model_output=y_pred) |
| 101 | + mlflow.log_metrics({"f1": f1, "accuracy": accuracy, "precision": precision, "recall": recall, "roc_auc": roc_auc}) |
| 102 | + signature = infer_signature(model_input=X_test, model_output=y_pred) |
91 | 103 |
|
92 | 104 | dataset = mlflow.data.from_spark(train_set_spark, table_name=f"{catalog_name}.{schema_name}.train_set", version="0")
|
93 | 105 | mlflow.log_input(dataset, context="training")
|
94 | 106 |
|
95 |
| - mlflow.sklearn.log_model(sk_model=pipeline, artifact_path="lightgbm-pipeline-model", signature=signature) |
| 107 | + mlflow.sklearn.log_model(sk_model=best_pipeline, artifact_path="lightgbm-pipeline-model", signature=signature) |
96 | 108 |
|
97 | 109 |
|
98 | 110 | # COMMAND ----------
|
99 | 111 | model_version = mlflow.register_model(
|
100 | 112 | model_uri=f"runs:/{run_id}/lightgbm-pipeline-model",
|
101 | 113 | name=f"{catalog_name}.{schema_name}.ad_click_model_basic",
|
102 |
| - tags={"git_sha": f"{git_sha}"}, |
| 114 | + tags=get_git_info(), |
103 | 115 | )
|
104 | 116 |
|
105 | 117 | # COMMAND ----------
|
|
0 commit comments