
feat: add CI/CD workflow for Databricks deployment #4

Open · wants to merge 2 commits into base: main
39 changes: 39 additions & 0 deletions .github/workflows/cd.yaml
@@ -0,0 +1,39 @@
name: MLOps with Databricks

Comment on lines +1 to +2

⚠️ Potential issue

Add security hardening measures

The workflow needs additional security configurations:

  1. Declare minimum required permissions
  2. Consider using OIDC federation instead of static token
 name: MLOps with Databricks
 
+permissions:
+  contents: read
+  id-token: write  # Required for OIDC
 
+# For OIDC authentication
+# Remove DATABRICKS_TOKEN and use:
+# configure-databricks-auth:
+#   uses: databricks/configure-auth@main
+#   with:
+#     databricks-environment: prod

Committable suggestion skipped: line range outside the PR's diff.

on:
push:
branches:
- 'main'
tags:
- '[0-9]+.[0-9]+.[0-9]+'


jobs:
setup-validate:
name: Set Up Environment
runs-on: ubuntu-latest
Comment on lines +11 to +14

⚠️ Potential issue

Enhance job configuration with deployment safeguards

The job configuration needs several improvements:

  1. The job name "setup-validate" doesn't reflect its deployment purpose
  2. Missing environment protection for production deployment
  3. Missing concurrency control to prevent parallel deployments
 jobs:
-  setup-validate:
-    name: Set Up Environment
+  deploy:
+    name: Deploy to Databricks
     runs-on: ubuntu-latest
+    environment:
+      name: production
+      url: ${{ vars.DATABRICKS_WORKSPACE_URL }}
+    concurrency:
+      group: databricks-${{ github.ref }}
+      cancel-in-progress: false


steps:
- name: Checkout Source Code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
ref: ${{ github.ref_name }}

- name: Set Up Python
uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa
with:
python-version: 3.11

- name: Install UV
uses: astral-sh/setup-uv@2e657c127d5b1635d5a8e3fa40e0ac50a5bf6992

Comment on lines +27 to +29

🛠️ Refactor suggestion

Remove unused UV installation step

The UV package manager is installed but never used in the workflow. Either remove this step or utilize UV for dependency management.

- name: Install Databricks CLI
uses: databricks/setup-cli@948d7379a31615a4c8e9ccbbc5445a12d6b32736
with:
version: 0.221.1

- name: Deploy to Databricks
env:
DATABRICKS_BUNDLE_ENV: prod # bundle target
DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
run: databricks bundle deploy --var="git_sha=${{ github.sha }}"
Comment on lines +35 to +39

⚠️ Potential issue

Enhance deployment step with validation and error handling

The deployment step needs several improvements:

  1. Missing pre-deployment validation
  2. No error handling or retries
  3. Missing workspace URL configuration
  4. No post-deployment verification
+      - name: Validate Bundle
+        env:
+          DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
+        run: databricks bundle validate
+
       - name: Deploy to Databricks
         env:
           DATABRICKS_BUNDLE_ENV: prod # bundle target
           DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
-        run: databricks bundle deploy --var="git_sha=${{ github.sha }}"
+          DATABRICKS_HOST: ${{ vars.DATABRICKS_WORKSPACE_URL }}
+        run: |
+          for i in {1..3}; do
+            if databricks bundle deploy \
+              --var="git_sha=${{ github.sha }}" \
+              --var="environment=prod"; then
+              exit 0
+            fi
+            echo "Deployment attempt $i failed. Retrying..."
+            sleep 10
+          done
+          exit 1
+
+      - name: Verify Deployment
+        if: success()
+        env:
+          DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
+          DATABRICKS_HOST: ${{ vars.DATABRICKS_WORKSPACE_URL }}
+        run: |
+          # Verify the deployment status
+          databricks bundle validate --target prod
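The bash retry loop in the suggestion above is a generic bounded-retry pattern; as a hedged sketch (hypothetical helper, not part of this PR), the same idea in Python looks like:

```python
import time


def retry(fn, attempts=3, delay=0.0):
    """Call fn up to `attempts` times, sleeping `delay` seconds between failures.

    Returns fn()'s result on the first success; raises RuntimeError chaining
    the last exception if every attempt fails.
    """
    last_exc = None
    for i in range(1, attempts + 1):
        try:
            return fn()
        except Exception as exc:  # broad on purpose: a deploy wrapper retries anything
            last_exc = exc
            print(f"Attempt {i} failed: {exc}. Retrying...")
            if delay:
                time.sleep(delay)
    raise RuntimeError(f"All {attempts} attempts failed") from last_exc
```

Wrapping the deploy call this way keeps the retry policy (attempt count, backoff) in one place instead of embedding it in each shell step.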

124 changes: 115 additions & 9 deletions databricks.yml
@@ -1,19 +1,125 @@
# This is a Databricks asset bundle definition for module-code-examples.
# The Databricks extension requires databricks.yml configuration file.
# See https://docs.databricks.com/dev-tools/bundles/index.html for documentation.

bundle:
name: armak-bundle
cluster_id: 1020-170922-owfkgqkz

💡 Codebase verification

Hardcoded cluster ID should be replaced with job cluster configuration

The cluster ID is only used once in the codebase, and interestingly, the same file already demonstrates a better practice by using a job cluster configuration (job_clusters) with a logical key "ad-click-cluster". You should:

  • Remove the hardcoded cluster_id: 1020-170922-owfkgqkz
  • Use the job cluster configuration pattern that's already defined in the file, which provides better maintainability and environment independence
  • Define cluster requirements (like spark_version and data_security_mode) instead of binding to a specific cluster ID
🔗 Analysis chain

Consider using cluster policies instead of hardcoded cluster IDs

Hardcoding cluster IDs in configuration files can be problematic as clusters may be terminated or recreated, breaking the workflow. Consider using cluster policies or templates to define cluster configurations consistently across environments.

Let's verify if this cluster ID is used elsewhere in the codebase:


Let me gather more context about the Databricks configuration patterns in the codebase:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for the cluster ID usage across the codebase
rg "1020-170922-owfkgqkz"

Length of output: 75


Script:

#!/bin/bash
# Look for other Databricks configuration files and cluster configurations
fd -e yml -e yaml | xargs rg -l "cluster"

# Search for any cluster policy references
rg -i "cluster.?policy"

# Check if there are any template or variable patterns for cluster configurations
rg "cluster_id:|cluster:|CLUSTER" -A 2 -B 2

Length of output: 538


artifacts:
default:
type: whl
build: python -m build
path: .

variables:
root_path:
description: root_path for the target
default: /Shared/.bundle/${bundle.target}/${bundle.name}
git_sha:
description: git_sha
default: ${bundle.git.commit}
git_branch:
description: git_branch
default: ${bundle.git.branch}


resources:
jobs:
ad-click:
name: ad-click-workflow
schedule:
quartz_cron_expression: "0 0 6 ? * MON"
timezone_id: "Europe/Warsaw"
pause_status: UNPAUSED
tags:
project_name: "ad-click"
job_clusters:
- job_cluster_key: "ad-click-cluster"
new_cluster:
spark_version: "15.4.x-scala2.12"
data_security_mode: "SINGLE_USER"
node_type_id: "i3.xlarge"
driver_node_type_id: "i3.xlarge"
autoscale:
min_workers: 1
max_workers: 1

tasks:
- task_key: "preprocessing"
job_cluster_key: "ad-click-cluster"
spark_python_task:
python_file: "src/mlops_with_databricks/pipeline/preprocess.py"
libraries:
- whl: ./dist/*.whl
- task_key: if_refreshed
condition_task:
op: "EQUAL_TO"
left: "{{tasks.preprocessing.values.refreshed}}"
right: "1"
depends_on:
- task_key: "preprocessing"
- task_key: "train_model"
depends_on:
- task_key: "if_refreshed"
outcome: "true"
job_cluster_key: "ad-click-cluster"
spark_python_task:
python_file: "src/mlops_with_databricks/pipeline/train_model.py"
parameters:
- "--git_sha"
- ${var.git_sha}
- "--git_branch"
- ${var.git_branch}
- "--job_run_id"
- "{{job.id}}"
Comment on lines +65 to +71

💡 Codebase verification

Parameters are validated but documentation is missing

The script validates the parameters through argparse with required=True flag, ensuring they cannot be empty. However, there's no documentation about the expected formats. Consider:

  • git_sha: Add documentation specifying the expected Git commit hash format
  • git_branch: Document the branch name format/constraints
  • job_run_id: Document the expected format of Databricks job run ID
🔗 Analysis chain

Validate task parameters

The training task parameters include sensitive information (git_sha, git_branch, job_run_id). Consider:

  1. Adding parameter validation in the Python script
  2. Documenting the expected parameter formats

Let's check if the Python script validates these parameters:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for parameter validation in train_model.py
ast-grep --pattern 'def main($$$):
  $$$
  parser.add_argument($$$)
  $$$' src/mlops_with_databricks/pipeline/train_model.py

Length of output: 126


Script:

#!/bin/bash
# Let's first check if the file exists and see its content
fd "train_model.py" --type f --exec cat {}

Length of output: 3677
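The documentation gap aside, format validation can be layered onto the existing argparse setup. A sketch of what that might look like — the SHA and branch-name rules here are assumptions for illustration, not project conventions:

```python
import argparse
import re


def git_sha_type(value: str) -> str:
    """Accept full (40-char) or abbreviated (>=7-char) lowercase hex commit hashes."""
    if not re.fullmatch(r"[0-9a-f]{7,40}", value):
        raise argparse.ArgumentTypeError(f"not a valid git SHA: {value!r}")
    return value


def git_branch_type(value: str) -> str:
    """Reject names git itself forbids: empty, leading '-', '..', or whitespace."""
    if not value or value.startswith("-") or ".." in value or any(c.isspace() for c in value):
        raise argparse.ArgumentTypeError(f"not a valid branch name: {value!r}")
    return value


parser = argparse.ArgumentParser()
parser.add_argument("--git_sha", type=git_sha_type, required=True)
parser.add_argument("--git_branch", type=git_branch_type, required=True)
parser.add_argument("--job_run_id", required=True)
```

With `type=` validators, a malformed value fails fast at parse time with a clear message instead of propagating into MLflow tags downstream.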

libraries:
- whl: ./dist/*.whl
- task_key: "evaluate_model"
depends_on:
- task_key: "train_model"
job_cluster_key: "ad-click-cluster"
spark_python_task:
python_file: "src/mlops_with_databricks/pipeline/evaluate_model.py"
parameters:
- "--new_model_uri"
- "{{tasks.train_model.values.new_model_uri}}"
- "--job_run_id"
- "{{job.id}}"
- "--git_sha"
- ${var.git_sha}
- "--git_branch"
- ${var.git_branch}
libraries:
- whl: ./dist/*.whl
- task_key: model_update
condition_task:
op: "EQUAL_TO"
left: "{{tasks.evaluate_model.values.model_update}}"
right: "1"
depends_on:
- task_key: "evaluate_model"
- task_key: "deploy_model"
depends_on:
- task_key: "model_update"
outcome: "true"
job_cluster_key: "ad-click-cluster"
spark_python_task:
python_file: "src/mlops_with_databricks/pipeline/deploy_model.py"
libraries:
- whl: ./dist/*.whl

targets:
dev:
mode: development
default: true
workspace:
host: https://dbc-643c4c2b-d6c9.cloud.databricks.com
root_path: /Workspace/Users/[email protected]/.bundle/${bundle.target}/${bundle.name}

⚠️ Potential issue

Review environment configuration

Several architectural concerns in the environment setup:

  1. All environments (dev/stage/prod) use the same Databricks host, which doesn't follow separation of concerns
  2. The workspace paths are using a personal email path ([email protected]), which isn't suitable for production
  3. Consider using environment-specific workspace paths and hosts

Consider restructuring the targets to use:

  1. Separate Databricks workspaces for each environment
  2. Organization-wide workspace paths instead of personal ones
  3. Environment-specific configurations through variables

Also applies to: 120-120, 125-125


## Optionally, there could be 'staging' or 'prod' targets here.
#
# prod:
# workspace:
# host: https://dbc-45ad9c70-3532.cloud.databricks.com
stage:
mode: development
workspace:
host: https://dbc-643c4c2b-d6c9.cloud.databricks.com
root_path: /Workspace/Users/[email protected]/.bundle/${bundle.target}/${bundle.name}

prod:
workspace:
host: https://dbc-643c4c2b-d6c9.cloud.databricks.com
root_path: /Workspace/Users/[email protected]/.bundle/${bundle.target}/${bundle.name}
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -17,6 +17,8 @@ dependencies = [
"matplotlib>=3.9.2, <4",
"databricks-feature-engineering>=0.6, <1",
"pre-commit>=4.0.1",
"loguru>=0.7.2",
"build>=1.2.2.post1",
]

[project.optional-dependencies]
25 changes: 23 additions & 2 deletions run.sh
@@ -22,5 +22,26 @@ function deploy_package() {
fi
}

# Run the function
deploy_package
function run_databricks_bundle_stage() {
echo "Running Databricks bundle deploy..."
databricks bundle deploy --profile DEFAULT --target stage
echo "Databricks bundle deploy successful"
echo "Generating data..."
"/Users/arturmakowski/Documents/Python_projects/marvelous-databricks-course-ArturMakowski/.venv/bin/python" "/Users/arturmakowski/.vscode/extensions/databricks.databricks-2.4.8-darwin-arm64/resources/python/dbconnect-bootstrap.py" "/Users/arturmakowski/Documents/Python_projects/marvelous-databricks-course-ArturMakowski/src/mlops_with_databricks/pipeline/generate_data.py"
echo "Data generated successfully"
echo "Running Databricks bundle run..."
databricks bundle run --profile DEFAULT --target stage
echo "Databricks bundle run successful"
}
Comment on lines +25 to +35

⚠️ Potential issue

Critical: Improve script robustness and portability

The staging deployment function has several issues that need to be addressed:

  1. Hardcoded absolute paths make the script non-portable and environment-dependent
  2. Missing error handling for critical operations
  3. Duplicated data generation logic between stage and prod functions

Consider refactoring using this approach:

+# Define common variables
+PYTHON_PATH="python"  # Or use environment variable
+DATA_GENERATOR_SCRIPT="src/mlops_with_databricks/pipeline/generate_data.py"

+function run_databricks_command() {
+    local command=$1
+    local target=$2
+    
+    echo "Running Databricks bundle $command..."
+    if ! databricks bundle $command --profile DEFAULT --target $target; then
+        echo "Error: Databricks bundle $command failed"
+        return 1
+    fi
+    echo "Databricks bundle $command successful"
+}

+function generate_data() {
+    echo "Generating data..."
+    if ! $PYTHON_PATH "$DATA_GENERATOR_SCRIPT"; then
+        echo "Error: Data generation failed"
+        return 1
+    fi
+    echo "Data generated successfully"
+}

 function run_databricks_bundle_stage() {
-    echo "Running Databricks bundle deploy..."
-    databricks bundle deploy --profile DEFAULT --target stage
-    echo "Databricks bundle deploy successful"
-    echo "Generating data..."
-    "/Users/arturmakowski/Documents/Python_projects/marvelous-databricks-course-ArturMakowski/.venv/bin/python" "/Users/arturmakowski/.vscode/extensions/databricks.databricks-2.4.8-darwin-arm64/resources/python/dbconnect-bootstrap.py" "/Users/arturmakowski/Documents/Python_projects/marvelous-databricks-course-ArturMakowski/src/mlops_with_databricks/pipeline/generate_data.py"
-    echo "Data generated successfully"
-    echo "Running Databricks bundle run..."
-    databricks bundle run --profile DEFAULT --target stage
-    echo "Databricks bundle run successful"
+    run_databricks_command "deploy" "stage" || return 1
+    generate_data || return 1
+    run_databricks_command "run" "stage" || return 1
 }


function run_databricks_bundle_prod() {
echo "Running Databricks bundle deploy..."
databricks bundle deploy --profile DEFAULT --target prod
echo "Databricks bundle deploy successful"
echo "Generating data..."
"/Users/arturmakowski/Documents/Python_projects/marvelous-databricks-course-ArturMakowski/.venv/bin/python" "/Users/arturmakowski/.vscode/extensions/databricks.databricks-2.4.8-darwin-arm64/resources/python/dbconnect-bootstrap.py" "/Users/arturmakowski/Documents/Python_projects/marvelous-databricks-course-ArturMakowski/src/mlops_with_databricks/pipeline/generate_data.py"
echo "Data generated successfully"
echo "Running Databricks bundle run..."
databricks bundle run --profile DEFAULT --target prod
echo "Databricks bundle run successful"
}
@@ -105,4 +105,4 @@ class LightGBMConfig(TypedDict):
max_depth: int


light_gbm_config = LightGBMConfig(learning_rate=0.001, n_estimators=200, max_depth=10)
light_gbm_config = LightGBMConfig(learning_rate=0.1, n_estimators=400, max_depth=15)
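The hyperparameter bump above lives in a TypedDict; a self-contained sketch of the pattern (class shape and values mirrored from the diff):

```python
from typing import TypedDict


class LightGBMConfig(TypedDict):
    """Hyperparameters passed through to the LightGBM estimator."""

    learning_rate: float
    n_estimators: int
    max_depth: int


# Updated values from this PR: a higher learning rate with more, deeper trees.
light_gbm_config = LightGBMConfig(learning_rate=0.1, n_estimators=400, max_depth=15)
```

Worth noting: a TypedDict documents expected keys and types but enforces nothing at runtime — a type checker such as mypy catches mismatches statically, so out-of-range values (e.g. a negative `max_depth`) would still need explicit validation.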
Empty file.
27 changes: 27 additions & 0 deletions src/mlops_with_databricks/pipeline/deploy_model.py
@@ -0,0 +1,27 @@
"""This script is used to deploy the model to the serving endpoint. The model version is fetched from the evaluate_model task."""

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ServedEntityInput

from mlops_with_databricks.data_preprocessing.dataclasses import DatabricksConfig, ModelConfig, ModelServingConfig

workspace = WorkspaceClient()

Comment on lines +8 to +9

🛠️ Refactor suggestion

Add error handling for workspace initialization

The workspace client initialization should include error handling and potentially accept configuration parameters for different environments.

Consider this improvement:

-workspace = WorkspaceClient()
+try:
+    workspace = WorkspaceClient(
+        host=DatabricksConfig.workspace_url,  # Add this to DatabricksConfig
+        token=DatabricksConfig.access_token   # Add this to DatabricksConfig
+    )
+except Exception as e:
+    raise RuntimeError(f"Failed to initialize Databricks workspace client: {str(e)}")

Committable suggestion skipped: line range outside the PR's diff.


model_version = workspace.dbutils.jobs.taskValues.get(taskKey="evaluate_model", key="model_version")
Comment on lines +10 to +11

⚠️ Potential issue

Add validation for model version retrieval

The model version retrieval lacks error handling and validation. This could lead to deployment failures if the task value is missing or invalid.

Suggested improvement:

-model_version = workspace.dbutils.jobs.taskValues.get(taskKey="evaluate_model", key="model_version")
+try:
+    model_version = workspace.dbutils.jobs.taskValues.get(taskKey="evaluate_model", key="model_version")
+    if not model_version:
+        raise ValueError("Model version not found in task values")
+    # Validate model version format if applicable
+    if not isinstance(model_version, (int, str)):
+        raise TypeError(f"Invalid model version type: {type(model_version)}")
+except Exception as e:
+    raise RuntimeError(f"Failed to retrieve model version: {str(e)}")



catalog_name = DatabricksConfig.catalog_name
schema_name = DatabricksConfig.schema_name

Comment on lines +14 to +16

🛠️ Refactor suggestion

Validate configuration values

Configuration values should be validated to ensure they are not empty or invalid before use.

Suggested improvement:

+def validate_config_value(name: str, value: str) -> str:
+    if not value or not isinstance(value, str):
+        raise ValueError(f"Invalid {name}: {value}")
+    return value
+
-catalog_name = DatabricksConfig.catalog_name
-schema_name = DatabricksConfig.schema_name
+catalog_name = validate_config_value("catalog_name", DatabricksConfig.catalog_name)
+schema_name = validate_config_value("schema_name", DatabricksConfig.schema_name)

workspace.serving_endpoints.update_config_and_wait(
name=ModelServingConfig.serving_endpoint_name,
served_entities=[
ServedEntityInput(
entity_name=f"{catalog_name}.{schema_name}.{ModelConfig.model_name}",
scale_to_zero_enabled=True,
workload_size="Small",
entity_version=model_version,
)
],
)
Comment on lines +17 to +27

⚠️ Potential issue

Enhance deployment robustness and monitoring

The endpoint update lacks error handling, timeout configuration, and deployment validation.

Consider these improvements:

+import time
+from databricks.sdk.service.serving import EndpointStateResponse
+
+def wait_for_endpoint_ready(workspace: WorkspaceClient, endpoint_name: str, timeout_seconds: int = 300) -> None:
+    start_time = time.time()
+    while time.time() - start_time < timeout_seconds:
+        state = workspace.serving_endpoints.get_state(name=endpoint_name)
+        if state.ready:
+            return
+        time.sleep(10)
+    raise TimeoutError(f"Endpoint {endpoint_name} not ready after {timeout_seconds} seconds")
+
+try:
     workspace.serving_endpoints.update_config_and_wait(
         name=ModelServingConfig.serving_endpoint_name,
         served_entities=[
             ServedEntityInput(
                 entity_name=f"{catalog_name}.{schema_name}.{ModelConfig.model_name}",
                 scale_to_zero_enabled=True,
                 workload_size="Small",
                 entity_version=model_version,
             )
         ],
+        timeout=300  # 5 minutes timeout
     )
+    # Validate deployment
+    wait_for_endpoint_ready(workspace, ModelServingConfig.serving_endpoint_name)
+    print(f"Successfully deployed model version {model_version} to endpoint {ModelServingConfig.serving_endpoint_name}")
+except Exception as e:
+    raise RuntimeError(f"Failed to deploy model to serving endpoint: {str(e)}")

121 changes: 121 additions & 0 deletions src/mlops_with_databricks/pipeline/evaluate_model.py
@@ -0,0 +1,121 @@
"""Evaluate the model and register it if it performs better than the previous model."""

import argparse
import sys

import mlflow
import mlflow.sklearn
from databricks import feature_engineering
from databricks.sdk import WorkspaceClient
from loguru import logger
from pyspark.sql import SparkSession
from sklearn.metrics import f1_score

from mlops_with_databricks.data_preprocessing.dataclasses import (
DatabricksConfig,
ModelServingConfig,
ProcessedAdClickDataConfig,
)

logger.remove()

logger.add(sink=sys.stderr, level="DEBUG")

parser = argparse.ArgumentParser()
parser.add_argument(
"--new_model_uri",
action="store",
default=None,
type=str,
required=True,
)

parser.add_argument(
"--job_run_id",
action="store",
default=None,
type=str,
required=True,
)

parser.add_argument(
"--git_sha",
action="store",
default=None,
type=str,
required=True,
)

parser.add_argument(
"--git_branch",
action="store",
default=None,
type=str,
required=True,
)


args = parser.parse_args()
new_model_uri = args.new_model_uri
job_run_id = args.job_run_id
git_sha = args.git_sha
git_branch = args.git_branch


spark = SparkSession.builder.getOrCreate()
workspace = WorkspaceClient()
fe = feature_engineering.FeatureEngineeringClient()

mlflow.set_registry_uri("databricks-uc")
mlflow.set_tracking_uri("databricks")

num_features = ProcessedAdClickDataConfig.num_features
cat_features = ProcessedAdClickDataConfig.cat_features
target = ProcessedAdClickDataConfig.target
catalog_name = DatabricksConfig.catalog_name
schema_name = DatabricksConfig.schema_name

serving_endpoint_name = ModelServingConfig.serving_endpoint_name
serving_endpoint = workspace.serving_endpoints.get(serving_endpoint_name)
model_name = serving_endpoint.config.served_models[0].model_name
model_version = serving_endpoint.config.served_models[0].model_version
previous_model_uri = f"models:/{model_name}/{model_version}"

test_set = spark.table(f"{catalog_name}.{schema_name}.test_set").toPandas()

X_test = test_set[list(num_features) + list(cat_features)]
y_test = test_set[target]

logger.debug(f"New Model URI: {new_model_uri}")
logger.debug(f"Previous Model URI: {previous_model_uri}")

model_new = mlflow.sklearn.load_model(new_model_uri)
predictions_new = model_new.predict(X_test)

model_previous = mlflow.sklearn.load_model(previous_model_uri)
predictions_previous = model_previous.predict(X_test)

Comment on lines +92 to +97

🛠️ Refactor suggestion

Add exception handling when loading models and making predictions.

Loading models and making predictions might fail due to issues like incorrect URIs or model incompatibilities. Adding exception handling will improve the robustness of the script.

You can modify the code to include try-except blocks:

try:
    model_new = mlflow.sklearn.load_model(new_model_uri)
    predictions_new = model_new.predict(X_test)
except Exception as e:
    logger.error(f"Failed to load or predict with new model: {e}")
    sys.exit(1)

try:
    model_previous = mlflow.sklearn.load_model(previous_model_uri)
    predictions_previous = model_previous.predict(X_test)
except Exception as e:
    logger.error(f"Failed to load or predict with previous model: {e}")
    sys.exit(1)

logger.info(f"Predictions for New Model: {predictions_new}")
logger.info(f"Predictions for Old Model: {predictions_previous}")


# Calculate F1 scores
f1_new = f1_score(y_test, predictions_new)
f1_previous = f1_score(y_test, predictions_previous)

logger.info(f"F1 Score for New Model: {f1_new}")
logger.info(f"F1 Score for Old Model: {f1_previous}")

if f1_new > f1_previous:
logger.info("New model performs better. Registering...")
model_version = mlflow.register_model(
model_uri=new_model_uri,
name=f"{catalog_name}.{schema_name}.ad_click_model_basic",
tags={"branch": git_branch, "git_sha": f"{git_sha}", "job_run_id": job_run_id},
)
workspace.dbutils.jobs.taskValues.set(key="model_version", value=model_version.version)
workspace.dbutils.jobs.taskValues.set(key="model_update", value=1)
logger.info(f"New model registered with version: {model_version.version}")
else:
logger.info("Previous model performs better. No update needed.")
workspace.dbutils.jobs.taskValues.set(key="model_update", value=0)
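The promotion gate above hinges entirely on `f1_score`. For reference, F1 is the harmonic mean of precision and recall, and a pure-Python version is handy as a sanity check against the sklearn result (illustrative sketch, not part of the pipeline):

```python
def f1(y_true, y_pred):
    """F1 for binary labels: harmonic mean of precision and recall."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    if tp == 0:
        # No true positives: precision and recall are both zero (or undefined).
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)
```

Because the comparison is a strict `>`, a tie keeps the previous model — a reasonable default that avoids churning the registry on equal performance.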