From 62e449a28026e6db6d14f1b9e8557cfe39423005 Mon Sep 17 00:00:00 2001 From: Ashish M Date: Tue, 26 Nov 2024 00:32:08 -0600 Subject: [PATCH] Separate config file by environment --- asset_bundles/wine_quality_model.yaml | 98 +++++++++ databricks.yml | 101 +--------- notebooks/week5/00.create_source_data.py | 48 ++--- project_config.yml => project_config_dev.yml | 10 + project_config_prod.yml | 40 ++++ tests/test_config.yml | 8 + week5/deploy_model.py | 57 ++++-- week5/evaluate_model.py | 202 ++++++++++--------- week5/preprocess.py | 55 +++-- week5/train_model.py | 41 ++-- week5/workspace_utils.py | 41 ++++ 11 files changed, 438 insertions(+), 263 deletions(-) create mode 100644 asset_bundles/wine_quality_model.yaml rename project_config.yml => project_config_dev.yml (68%) create mode 100644 project_config_prod.yml create mode 100644 week5/workspace_utils.py diff --git a/asset_bundles/wine_quality_model.yaml b/asset_bundles/wine_quality_model.yaml new file mode 100644 index 0000000..4e68bcd --- /dev/null +++ b/asset_bundles/wine_quality_model.yaml @@ -0,0 +1,98 @@ +resources: + jobs: + wine-quality: + name: wine-quality-workflow + schedule: + quartz_cron_expression: "0 0 10 ? * MONDAY *" + timezone_id: "America/Chicago" + pause_status: ${var.schedule_pause_status} + tags: + project_name: "wine-quality" + job_clusters: + - job_cluster_key: "wine-quality-cluster" + new_cluster: + spark_version: "15.4.x-scala2.12" + data_security_mode: "SINGLE_USER" + node_type_id: "r3.xlarge" + driver_node_type_id: "r3.xlarge" + autoscale: + min_workers: 1 + max_workers: 1 + + tasks: + - task_key: "preprocessing" + job_cluster_key: "wine-quality-cluster" + spark_python_task: + python_file: "../week5/preprocess.py" + parameters: + - "--root_path" + - ${var.root_path} + - "--env" + - ${var.env} + libraries: + - whl: ../dist/*.whl + - task_key: if_refreshed + condition_task: + op: "EQUAL_TO" + left: "{{tasks.preprocessing.values.refreshed}}" + right: "1" + depends_on: + - task_key: "preprocessing" + - task_key: "train_model" + depends_on: + - task_key: "if_refreshed" + outcome: "true" + job_cluster_key: "wine-quality-cluster" + spark_python_task: + python_file: "../week5/train_model.py" + parameters: + - "--root_path" + - ${var.root_path} + - "--git_sha" + - ${var.git_sha} + - "--job_run_id" + - "{{job.id}}" + - "--env" + - ${var.env} + libraries: + - whl: ../dist/*.whl + - task_key: "evaluate_model" + depends_on: + - task_key: "train_model" + job_cluster_key: "wine-quality-cluster" + spark_python_task: + python_file: "../week5/evaluate_model.py" + parameters: + - "--root_path" + - ${var.root_path} + - "--new_model_uri" + - "{{tasks.train_model.values.new_model_uri}}" + - "--job_run_id" + - "{{job.id}}" + - "--git_sha" + - ${var.git_sha} + - "--env" + - ${var.env} + libraries: + - whl: ../dist/*.whl + - task_key: model_update + condition_task: + op: "EQUAL_TO" + left: "{{tasks.evaluate_model.values.model_update}}" + right: "1" + depends_on: + - task_key: "evaluate_model" + - task_key: "deploy_model" + depends_on: + - task_key: "model_update" + outcome: "true" + job_cluster_key: "wine-quality-cluster" + spark_python_task: + python_file: "../week5/deploy_model.py" + parameters: + - "--root_path" + - ${var.root_path} + - "--env" + - ${var.env} + libraries: + - whl: ../dist/*.whl \ No newline at end of file diff --git a/databricks.yml b/databricks.yml index 0a1face..c819adf 100644 --- a/databricks.yml +++ b/databricks.yml @@ -7,6 +7,9 @@ bundle: #databricks_cli_version: "0.230.0" cluster_id: 1018-024853-shr990fc +include: + - asset_bundles/*.yaml + permissions: - level: CAN_MANAGE user_name: mahajan134@gmail.com @@ -29,9 +32,14 @@ variables: schedule_pause_status: description: schedule pause status default: UNPAUSED + env: + description: environment + default: dev targets: prod: + variables: + env: prod workspace: host: https://dbc-643c4c2b-d6c9.cloud.databricks.com root_path: ${var.root_path} @@ -39,97 +47,8 @@ targets: dev: mode: development default: true + variables: + env: dev workspace: host: https://dbc-643c4c2b-d6c9.cloud.databricks.com # root_path: /Workspace/Users/mahajan134@gmail.com/.bundle/${bundle.name}/dev - -resources: - jobs: - wine-quality: - name: wine-quality-workflow - schedule: - quartz_cron_expression: "0 0 10 ? * MONDAY *" - timezone_id: "America/Chicago" - pause_status: ${var.schedule_pause_status} - tags: - project_name: "wine-quality" - job_clusters: - - job_cluster_key: "wine-quality-cluster" - new_cluster: - spark_version: "15.4.x-scala2.12" - data_security_mode: "SINGLE_USER" - node_type_id: "r3.xlarge" - driver_node_type_id: "r3.xlarge" - autoscale: - min_workers: 1 - max_workers: 1 - - tasks: - - task_key: "preprocessing" - job_cluster_key: "wine-quality-cluster" - spark_python_task: - python_file: "week5/preprocess.py" - parameters: - - "--root_path" - - ${var.root_path} - libraries: - - whl: ./dist/*.whl - - task_key: if_refreshed - condition_task: - op: "EQUAL_TO" - left: "{{tasks.preprocessing.values.refreshed}}" - right: "1" - depends_on: - - task_key: "preprocessing" - - task_key: "train_model" - depends_on: - - task_key: "if_refreshed" - outcome: "true" - job_cluster_key: "wine-quality-cluster" - spark_python_task: - python_file: "week5/train_model.py" - parameters: - - "--root_path" - - ${var.root_path} - - "--git_sha" - - ${var.git_sha} - - "--job_run_id" - - "{{job.id}}" - libraries: - - whl: ./dist/*.whl - - task_key: "evaluate_model" - depends_on: - - task_key: "train_model" - job_cluster_key: "wine-quality-cluster" - spark_python_task: - python_file: "week5/evaluate_model.py" - parameters: - - "--root_path" - - ${var.root_path} - - "--new_model_uri" - - "{{tasks.train_model.values.new_model_uri}}" - - "--job_run_id" - - "{{job.id}}" - - "--git_sha" - - ${var.git_sha} - libraries: - - whl: ./dist/*.whl - - task_key: model_update - condition_task: - op: "EQUAL_TO" - left: "{{tasks.evaluate_model.values.model_update}}" - right: "1" - depends_on: - - task_key: "evaluate_model" - - task_key: "deploy_model" - depends_on: - - task_key: "model_update" - outcome: "true" - job_cluster_key: "wine-quality-cluster" - spark_python_task: - python_file: "week5/deploy_model.py" - parameters: - - "--root_path" - - ${var.root_path} - libraries: - - whl: ./dist/*.whl diff --git a/notebooks/week5/00.create_source_data.py b/notebooks/week5/00.create_source_data.py index 17bc1db..d63cc55 100644 --- a/notebooks/week5/00.create_source_data.py +++ b/notebooks/week5/00.create_source_data.py @@ -1,7 +1,8 @@ -import pandas as pd import numpy as np +import pandas as pd from pyspark.sql import SparkSession from pyspark.sql.functions import current_timestamp, to_utc_timestamp + from wine_quality.config import ProjectConfig # Load configuration @@ -14,55 +15,58 @@ train_set = spark.table(f"{catalog_name}.{schema_name}.train_set").toPandas() test_set = spark.table(f"{catalog_name}.{schema_name}.test_set").toPandas() combined_set = pd.concat([train_set, test_set], ignore_index=True) -existing_ids = set(int(id) for id in combined_set['id']) +existing_ids = set(int(id) for id in combined_set["id"]) + # Define function to create synthetic data without random state def create_synthetic_data(df, num_rows=100): synthetic_data = pd.DataFrame() - + for column in df.columns: # Treat float and int differently - # if pd.api.types.is_numeric_dtype(df[column]) and column != 'id': - # mean, std = df[column].mean(), df[column].std() - # synthetic_data[column] = np.random.normal(mean, std, num_rows) + # if pd.api.types.is_numeric_dtype(df[column]) and column != 'id': + # mean, std = df[column].mean(), df[column].std() + # synthetic_data[column] = np.random.normal(mean, std, num_rows) if pd.api.types.is_float_dtype(df[column]): mean, std = df[column].mean(), df[column].std() synthetic_data[column] = np.random.normal(mean, std, num_rows) - elif pd.api.types.is_integer_dtype(df[column]) and column != 'id': + elif pd.api.types.is_integer_dtype(df[column]) and column != "id": mean, std = df[column].mean(), df[column].std() synthetic_data[column] = np.random.normal(mean, std, num_rows).astype(int) elif pd.api.types.is_categorical_dtype(df[column]) or pd.api.types.is_object_dtype(df[column]): - synthetic_data[column] = np.random.choice(df[column].unique(), num_rows, - p=df[column].value_counts(normalize=True)) - + synthetic_data[column] = np.random.choice( + df[column].unique(), num_rows, p=df[column].value_counts(normalize=True) + ) + elif pd.api.types.is_datetime64_any_dtype(df[column]): min_date, max_date = df[column].min(), df[column].max() if min_date < max_date: - synthetic_data[column] = pd.to_datetime( - np.random.randint(min_date.value, max_date.value, num_rows) - ) + synthetic_data[column] = pd.to_datetime(np.random.randint(min_date.value, max_date.value, num_rows)) else: synthetic_data[column] = [min_date] * num_rows - + else: synthetic_data[column] = np.random.choice(df[column], num_rows) - - # Making sure that generated IDs are unique and do not previously exist + + # Making sure that generated IDs are unique and do not previously exist new_ids = [] i = max(existing_ids) + 1 if existing_ids else 1 while len(new_ids) < num_rows: if i not in existing_ids: - new_ids.append(i) # Id needs to be string, but leaving it as int to match train/test set. Will convert to string later. - #new_ids.append(str(i)) # Convert numeric ID to string + new_ids.append( + i + ) # Id needs to be string, but leaving it as int to match train/test set. Will convert to string later. + # new_ids.append(str(i)) # Convert numeric ID to string i += 1 - synthetic_data['id'] = new_ids + synthetic_data["id"] = new_ids return synthetic_data + # Create synthetic data synthetic_df = create_synthetic_data(combined_set) -# Create source_data table manually using Create table like train_set +# Create source_data table manually using Create table like train_set existing_schema = spark.table(f"{catalog_name}.{schema_name}.source_data").schema synthetic_spark_df = spark.createDataFrame(synthetic_df, schema=existing_schema) @@ -72,6 +76,4 @@ def create_synthetic_data(df, num_rows=100): ) # Append synthetic data as new data to source_data table -train_set_with_timestamp.write.mode("append").saveAsTable( - f"{catalog_name}.{schema_name}.source_data" -) \ No newline at end of file +train_set_with_timestamp.write.mode("append").saveAsTable(f"{catalog_name}.{schema_name}.source_data") diff --git a/project_config.yml b/project_config_dev.yml similarity index 68% rename from project_config.yml rename to project_config_dev.yml index 03f6d7c..ff046bd 100644 --- a/project_config.yml +++ b/project_config_dev.yml @@ -2,6 +2,16 @@ catalog_name: mlops_students schema_name: mahajan134 pipeline_id: 5d52b992-0c14-4abe-a8d8-3880ad86fe93 +dev: + catalog_name: mlops_students + schema_name: mahajan134 + pipeline_id: 5d52b992-0c14-4abe-a8d8-3880ad86fe93 + +prod: + catalog_name: mlops_students + schema_name: mahajan134 + pipeline_id: 5d52b992-0c14-4abe-a8d8-3880ad86fe93 + parameters: learning_rate: 0.01 n_estimators: 1000 diff --git a/project_config_prod.yml b/project_config_prod.yml new file mode 100644 index 0000000..ff046bd --- /dev/null +++ b/project_config_prod.yml @@ -0,0 +1,40 @@ +catalog_name: mlops_students +schema_name: mahajan134 +pipeline_id: 5d52b992-0c14-4abe-a8d8-3880ad86fe93 + +dev: + catalog_name: mlops_students + schema_name: mahajan134 + pipeline_id: 5d52b992-0c14-4abe-a8d8-3880ad86fe93 + +prod: + catalog_name: mlops_students + schema_name: mahajan134 + pipeline_id: 5d52b992-0c14-4abe-a8d8-3880ad86fe93 + +parameters: + learning_rate: 0.01 + n_estimators: 1000 + max_depth: 6 + +ab_test: + learning_rate_a: 0.02 + learning_rate_b: 0.02 + n_estimators: 1000 + max_depth_a: 6 + max_depth_b: 10 + +num_features: + - fixed_acidity + - volatile_acidity + - citric_acid + - residual_sugar + - chlorides + - free_sulfur_dioxide + - total_sulfur_dioxide + - density + - pH + - sulphates + - alcohol + +target: quality diff --git a/tests/test_config.yml b/tests/test_config.yml index 20f1ed4..03f6d7c 100644 --- a/tests/test_config.yml +++ b/tests/test_config.yml @@ -1,11 +1,19 @@ catalog_name: mlops_students schema_name: mahajan134 +pipeline_id: 5d52b992-0c14-4abe-a8d8-3880ad86fe93 parameters: learning_rate: 0.01 n_estimators: 1000 max_depth: 6 +ab_test: + learning_rate_a: 0.02 + learning_rate_b: 0.02 + n_estimators: 1000 + max_depth_a: 6 + max_depth_b: 10 + num_features: - fixed_acidity - volatile_acidity diff --git a/week5/deploy_model.py b/week5/deploy_model.py index 3454247..041d31e 100644 --- a/week5/deploy_model.py +++ b/week5/deploy_model.py @@ -11,12 +11,14 @@ The endpoint is configured for feature-engineered model serving with automatic scaling. """ -import yaml import argparse + from databricks.sdk import WorkspaceClient -from databricks.sdk.service.serving import ServedEntityInput +from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedEntityInput from pyspark.dbutils import DBUtils from pyspark.sql import SparkSession +from workspace_utils import check_if_endpoint_exists, get_env_config_file + from wine_quality.config import ProjectConfig parser = argparse.ArgumentParser() @@ -27,13 +29,23 @@ type=str, required=True, ) +parser.add_argument( + "--env", + action="store", + default=None, + type=str, + required=True, +) args = parser.parse_args() root_path = args.root_path -config_path = (f"{root_path}/project_config.yml") +config_file_name = get_env_config_file(args.env) +config_path = f"{root_path}/{config_file_name}" # config_path = ("/Volumes/mlops_students/mahajan134/mlops_vol/project_config.yml") config = ProjectConfig.from_yaml(config_path=config_path) +endpoint_name = "wine-quality-model-serving-fe" +model_name = "wine-quality-model-fe" spark = SparkSession.builder.getOrCreate() dbutils = DBUtils(spark) @@ -45,14 +57,31 @@ catalog_name = config.catalog_name schema_name = config.schema_name -workspace.serving_endpoints.update_config_and_wait( - name="wine-quality-model-serving-fe", - served_entities=[ - ServedEntityInput( - entity_name=f"{catalog_name}.{schema_name}.wine-quality-model-fe", - scale_to_zero_enabled=True, - workload_size="Small", - entity_version=model_version, - ) - ], -) \ No newline at end of file +if check_if_endpoint_exists(workspace, endpoint_name): + print(f"Endpoint '{endpoint_name}' exists.") + workspace.serving_endpoints.update_config_and_wait( + name=endpoint_name, + served_entities=[ + ServedEntityInput( + entity_name=f"{catalog_name}.{schema_name}.{model_name}", + scale_to_zero_enabled=True, + workload_size="Small", + entity_version=model_version, + ) + ], + ) +else: + print(f"Endpoint '{endpoint_name}' does not exist, creating new endpoint.") + workspace.serving_endpoints.create_and_wait( + name=endpoint_name, + config=EndpointCoreConfigInput( + served_entities=[ + ServedEntityInput( + entity_name=f"{catalog_name}.{schema_name}.{model_name}", + scale_to_zero_enabled=True, + workload_size="Small", + entity_version=1, + ) + ] + ), + ) diff --git a/week5/evaluate_model.py b/week5/evaluate_model.py index e1b719a..7c94a25 100644 --- a/week5/evaluate_model.py +++ b/week5/evaluate_model.py @@ -16,25 +16,20 @@ 6. Updates pipeline task values with results """ +import argparse + +import mlflow from databricks import feature_engineering from databricks.sdk import WorkspaceClient -from pyspark.sql import SparkSession from pyspark.dbutils import DBUtils - -from pyspark.sql import functions as F -from pyspark.sql.functions import col, lit, when from pyspark.ml.evaluation import RegressionEvaluator -from datetime import datetime -import mlflow -import argparse +from pyspark.sql import SparkSession from pyspark.sql import functions as F -from pyspark.sql import DataFrame -from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.sql.functions import col, lit, when +from workspace_utils import check_if_endpoint_exists, get_env_config_file, get_model_entity_index from wine_quality.config import ProjectConfig - - parser = argparse.ArgumentParser() parser.add_argument( "--root_path", @@ -67,6 +62,13 @@ required=True, ) +parser.add_argument( + "--env", + action="store", + default=None, + type=str, + required=True, +) args = parser.parse_args() root_path = args.root_path @@ -74,8 +76,10 @@ job_run_id = args.job_run_id git_sha = args.git_sha -config_path = (f"{root_path}/project_config.yml") +# config_path = (f"{root_path}/project_config.yml") # config_path = ("/Volumes/mlops_test/house_prices/data/project_config.yml") +config_file_name = get_env_config_file(args.env) +config_path = f"{root_path}/{config_file_name}" config = ProjectConfig.from_yaml(config_path=config_path) spark = SparkSession.builder.getOrCreate() @@ -86,96 +90,104 @@ mlflow.set_registry_uri("databricks-uc") mlflow.set_tracking_uri("databricks") -def get_latest_entity_version(serving_endpoint): - latest_version = 1 - for entity in serving_endpoint.config.served_entities: - version_int = int(entity.entity_version) - if version_int > latest_version: - latest_version = version_int - return latest_version - # Extract configuration details num_features = config.num_features target = config.target catalog_name = config.catalog_name schema_name = config.schema_name - -# Define the serving endpoint -serving_endpoint_name = "wine-quality-model-serving-fe" -serving_endpoint = workspace.serving_endpoints.get(serving_endpoint_name) - -model_name = serving_endpoint.config.served_models[0].model_name -model_version = serving_endpoint.config.served_models[0].model_version -# latest_entity = get_latest_entity_version(serving_endpoint) -# print(latest_entity) -# model_name = serving_endpoint.config.served_entities[latest_entity].model_name -# model_version = serving_endpoint.config.served_entities[latest_entity].model_version - -previous_model_uri = f"models:/{model_name}/{model_version}" -cur_model_name = "wine-quality-model-fe" - - -# Load test set and create additional features in Spark DataFrame -test_set = spark.table(f"{catalog_name}.{schema_name}.test_set") -# Cast residual_sugar to int for the function inputs -test_set = test_set.withColumn("residual_sugar", test_set["residual_sugar"].cast("int")) -test_set = test_set.withColumn("is_sweet_indicator", when(col("residual_sugar") > 40, lit('1')).otherwise(lit('0'))) -test_set = test_set.withColumn("is_sweet_indicator", test_set["residual_sugar"].cast("int")) -test_set = test_set.withColumn("Id", test_set["id"].cast("string")) - - -# Select the necessary columns for prediction and target -X_test_spark = test_set.select(num_features + ["is_sweet_indicator", "Id"]) -y_test_spark = test_set.select("Id", target) - - -# Generate predictions from both models -predictions_previous = fe.score_batch(model_uri=previous_model_uri, df=X_test_spark) -predictions_new = fe.score_batch(model_uri=new_model_uri, df=X_test_spark) - -predictions_new = predictions_new.withColumnRenamed("prediction", "prediction_new") -predictions_old = predictions_previous.withColumnRenamed("prediction", "prediction_old") -test_set = test_set.select("Id", "quality") - -# Join the DataFrames on the 'id' column -df = test_set \ - .join(predictions_new, on="Id") \ - .join(predictions_old, on="Id") - -# Calculate the absolute error for each model -df = df.withColumn("error_new", F.abs(df["quality"] - df["prediction_new"])) -df = df.withColumn("error_old", F.abs(df["quality"] - df["prediction_old"])) - -# Calculate the absolute error for each model -df = df.withColumn("error_new", F.abs(df["quality"] - df["prediction_new"])) -df = df.withColumn("error_old", F.abs(df["quality"] - df["prediction_old"])) - -# Calculate the Mean Absolute Error (MAE) for each model -mae_new = df.agg(F.mean("error_new")).collect()[0][0] -mae_old = df.agg(F.mean("error_old")).collect()[0][0] - -# Calculate the Root Mean Squared Error (RMSE) for each model -evaluator = RegressionEvaluator(labelCol="quality", predictionCol="prediction_new", metricName="rmse") -rmse_new = evaluator.evaluate(df) - -evaluator.setPredictionCol("prediction_old") -rmse_old = evaluator.evaluate(df) - -# Compare models based on MAE and RMSE -print(f"MAE for New Model: {mae_new}") -print(f"MAE for Old Model: {mae_old}") - -if mae_new < mae_old: - print("New model is better based on MAE.") +endpoint_name = "wine-quality-model-serving-fe" +model_name = "wine-quality-model-fe" +full_model_name = f"{catalog_name}.{schema_name}.{model_name}" + + +if check_if_endpoint_exists(workspace, endpoint_name): + # Define the serving endpoint + serving_endpoint = workspace.serving_endpoints.get(endpoint_name) + + # This assumes that only one model is served by the endpoint + # model_name = serving_endpoint.config.served_models[0].model_name + # model_version = serving_endpoint.config.served_models[0].model_version + # print(model_name) + # print(model_version) + + # Loop over served entities and get model name and version. Can be used for multiple models. + # served_entity_version = model_entity_version(workspace, endpoint_name, f"{catalog_name}.{schema_name}.{model_name}") + served_model_index = get_model_entity_index(workspace, endpoint_name, full_model_name) + + if served_model_index < 0: + raise ValueError("Model not found in the serving endpoint.") + + model_name = serving_endpoint.config.served_models[served_model_index].model_name + model_version = serving_endpoint.config.served_models[served_model_index].model_version + print(model_name) + print(model_version) + + previous_model_uri = f"models:/{model_name}/{model_version}" + + # Load test set and create additional features in Spark DataFrame + test_set = spark.table(f"{catalog_name}.{schema_name}.test_set") + # Cast residual_sugar to int for the function inputs + test_set = test_set.withColumn("residual_sugar", test_set["residual_sugar"].cast("int")) + test_set = test_set.withColumn("is_sweet_indicator", when(col("residual_sugar") > 40, lit("1")).otherwise(lit("0"))) + test_set = test_set.withColumn("is_sweet_indicator", test_set["residual_sugar"].cast("int")) + test_set = test_set.withColumn("Id", test_set["id"].cast("string")) + + # Select the necessary columns for prediction and target + X_test_spark = test_set.select(num_features + ["is_sweet_indicator", "Id"]) + y_test_spark = test_set.select("Id", target) + + # Generate predictions from both models + predictions_previous = fe.score_batch(model_uri=previous_model_uri, df=X_test_spark) + predictions_new = fe.score_batch(model_uri=new_model_uri, df=X_test_spark) + + predictions_new = predictions_new.withColumnRenamed("prediction", "prediction_new") + predictions_old = predictions_previous.withColumnRenamed("prediction", "prediction_old") + test_set = test_set.select("Id", "quality") + + # Join the DataFrames on the 'id' column + df = test_set.join(predictions_new, on="Id").join(predictions_old, on="Id") + + # Calculate the absolute error for each model + df = df.withColumn("error_new", F.abs(df["quality"] - df["prediction_new"])) + df = df.withColumn("error_old", F.abs(df["quality"] - df["prediction_old"])) + + # Calculate the absolute error for each model + df = df.withColumn("error_new", F.abs(df["quality"] - df["prediction_new"])) + df = df.withColumn("error_old", F.abs(df["quality"] - df["prediction_old"])) + + # Calculate the Mean Absolute Error (MAE) for each model + mae_new = df.agg(F.mean("error_new")).collect()[0][0] + mae_old = df.agg(F.mean("error_old")).collect()[0][0] + + # Calculate the Root Mean Squared Error (RMSE) for each model + evaluator = RegressionEvaluator(labelCol="quality", predictionCol="prediction_new", metricName="rmse") + rmse_new = evaluator.evaluate(df) + + evaluator.setPredictionCol("prediction_old") + rmse_old = evaluator.evaluate(df) + + # Compare models based on MAE and RMSE + print(f"MAE for New Model: {mae_new}") + print(f"MAE for Old Model: {mae_old}") + + if mae_new < mae_old: + print("New model is better based on MAE.") + model_version = mlflow.register_model( + model_uri=new_model_uri, name=full_model_name, tags={"git_sha": f"{git_sha}", "job_run_id": job_run_id} + ) + + print("New model registered with version:", model_version.version) + dbutils.jobs.taskValues.set(key="model_version", value=model_version.version) + dbutils.jobs.taskValues.set(key="model_update", value=1) + else: + print("Old model is better based on MAE.") + dbutils.jobs.taskValues.set(key="model_update", value=0) +else: + print("Endpoint not found, registering new model.") model_version = mlflow.register_model( - model_uri=new_model_uri, - name=f"{catalog_name}.{schema_name}.{cur_model_name}", - tags={"git_sha": f"{git_sha}", - "job_run_id": job_run_id}) + model_uri=new_model_uri, name=full_model_name, tags={"git_sha": f"{git_sha}", "job_run_id": job_run_id} + ) print("New model registered with version:", model_version.version) dbutils.jobs.taskValues.set(key="model_version", value=model_version.version) dbutils.jobs.taskValues.set(key="model_update", value=1) -else: - print("Old model is better based on MAE.") - dbutils.jobs.taskValues.set(key="model_update", value=0) \ No newline at end of file diff --git a/week5/preprocess.py b/week5/preprocess.py index 83712ef..f45191b 100644 --- a/week5/preprocess.py +++ b/week5/preprocess.py @@ -18,16 +18,17 @@ 6. Set a task value indicating whether new data was processed. """ -import yaml import argparse +import time + +from databricks.sdk import WorkspaceClient from pyspark.dbutils import DBUtils from pyspark.sql import SparkSession -from pyspark.sql.functions import col, max as spark_max -import yaml -from databricks.sdk import WorkspaceClient -import time -from wine_quality.config import ProjectConfig +from pyspark.sql.functions import col +from pyspark.sql.functions import max as spark_max +from workspace_utils import get_env_config_file +from wine_quality.config import ProjectConfig workspace = WorkspaceClient() @@ -39,11 +40,19 @@ type=str, required=True, ) - +parser.add_argument( + "--env", + action="store", + default=None, + type=str, + required=True, +) args = parser.parse_args() root_path = args.root_path print(f"Root path: {root_path}") -config_path = (f"{root_path}/project_config.yml") + +config_file_name = get_env_config_file(args.env) +config_path = f"{root_path}/{config_file_name}" config = ProjectConfig.from_yaml(config_path=config_path) pipeline_id = config.pipeline_id @@ -59,13 +68,17 @@ source_data = spark.table(f"{catalog_name}.{schema_name}.source_data") # Get max update timestamps from existing data -max_train_timestamp = spark.table(f"{catalog_name}.{schema_name}.train_set") \ - .select(spark_max("update_timestamp_utc").alias("max_update_timestamp")) \ +max_train_timestamp = ( + spark.table(f"{catalog_name}.{schema_name}.train_set") + .select(spark_max("update_timestamp_utc").alias("max_update_timestamp")) .collect()[0]["max_update_timestamp"] +) -max_test_timestamp = spark.table(f"{catalog_name}.{schema_name}.test_set") \ - .select(spark_max("update_timestamp_utc").alias("max_update_timestamp")) \ +max_test_timestamp = ( + spark.table(f"{catalog_name}.{schema_name}.test_set") + .select(spark_max("update_timestamp_utc").alias("max_update_timestamp")) .collect()[0]["max_update_timestamp"] +) latest_timestamp = max(max_train_timestamp, max_test_timestamp) @@ -83,8 +96,8 @@ affected_rows_train = new_data_train.count() affected_rows_test = new_data_test.count() -#write into feature table; update online table -if affected_rows_train > 0 or affected_rows_test > 0 : +# write into feature table; update online table +if affected_rows_train > 0 or affected_rows_test > 0: spark.sql(f""" WITH max_timestamp AS ( SELECT MAX(update_timestamp_utc) AS max_update_timestamp @@ -106,17 +119,15 @@ WHERE update_timestamp_utc == (SELECT max_update_timestamp FROM max_timestamp) """) refreshed = 1 - update_response = workspace.pipelines.start_update( - pipeline_id=pipeline_id, full_refresh=False) + update_response = workspace.pipelines.start_update(pipeline_id=pipeline_id, full_refresh=False) while True: - update_info = workspace.pipelines.get_update(pipeline_id=pipeline_id, - update_id=update_response.update_id) + update_info = workspace.pipelines.get_update(pipeline_id=pipeline_id, update_id=update_response.update_id) state = update_info.update.state.value - if state == 'COMPLETED': + if state == "COMPLETED": break - elif state in ['FAILED', 'CANCELED']: + elif state in ["FAILED", "CANCELED"]: raise SystemError("Online table failed to update.") - elif state == 'WAITING_FOR_RESOURCES': + elif state == "WAITING_FOR_RESOURCES": print("Pipeline is waiting for resources.") else: print(f"Pipeline is in {state} state.") @@ -124,4 +135,4 @@ else: refreshed = 0 -dbutils.jobs.taskValues.set(key="refreshed", value=refreshed) \ No newline at end of file +dbutils.jobs.taskValues.set(key="refreshed", value=refreshed) diff --git a/week5/train_model.py b/week5/train_model.py index c979c50..0b0ca04 100644 --- a/week5/train_model.py +++ b/week5/train_model.py @@ -12,25 +12,24 @@ The model uses numerical features, including a custom calculated wine sweetness feature. """ +import argparse + +import mlflow from databricks import feature_engineering -from pyspark.dbutils import DBUtils -from pyspark.sql import SparkSession +from databricks.feature_engineering import FeatureFunction, FeatureLookup from databricks.sdk import WorkspaceClient -import mlflow -import argparse -from datetime import datetime from lightgbm import LGBMRegressor from mlflow.models import infer_signature -from pyspark.sql import functions as F +from pyspark.dbutils import DBUtils +from pyspark.sql import SparkSession from sklearn.compose import ColumnTransformer from sklearn.impute import SimpleImputer -from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score +from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler -from databricks.feature_engineering import FeatureFunction, FeatureLookup -from wine_quality.config import ProjectConfig - +from workspace_utils import get_env_config_file +from wine_quality.config import ProjectConfig parser = argparse.ArgumentParser() parser.add_argument( @@ -54,6 +53,13 @@ type=str, required=True, ) +parser.add_argument( + "--env", + action="store", + default=None, + type=str, + required=True, +) args = parser.parse_args() root_path = args.root_path @@ -63,9 +69,10 @@ cur_model_artifact_path = "wine-quality-model-fe" -config_path = (f"{root_path}/project_config.yml") +# config_path = (f"{root_path}/project_config.yml") # config_path = ("/Volumes/mlops_students/mahajan134/mlops_vol/project_config.yml") - +config_file_name = get_env_config_file(args.env) +config_path = f"{root_path}/{config_file_name}" config = ProjectConfig.from_yaml(config_path=config_path) # Initialize the Databricks session and clients @@ -114,7 +121,7 @@ input_bindings={"residual_sugar_content": "residual_sugar"}, ), ], - exclude_columns=["update_timestamp_utc"] + exclude_columns=["update_timestamp_utc"], ) # Load feature-engineered DataFrame @@ -141,9 +148,7 @@ mlflow.set_experiment(experiment_name=cur_experiment_name) -with mlflow.start_run(tags={"branch": "week5", - "git_sha": f"{git_sha}", - "job_run_id": job_run_id}) as run: +with mlflow.start_run(tags={"branch": "week5", "git_sha": f"{git_sha}", "job_run_id": job_run_id}) as run: run_id = run.info.run_id pipeline.fit(X_train, y_train) y_pred = pipeline.predict(X_test) @@ -173,5 +178,5 @@ signature=signature, ) -model_uri=f'runs:/{run_id}/{cur_model_artifact_path}' -dbutils.jobs.taskValues.set(key="new_model_uri", value=model_uri) \ No newline at end of file +model_uri = f"runs:/{run_id}/{cur_model_artifact_path}" +dbutils.jobs.taskValues.set(key="new_model_uri", value=model_uri) diff --git a/week5/workspace_utils.py b/week5/workspace_utils.py new file mode 100644 index 0000000..eac9b18 --- /dev/null +++ b/week5/workspace_utils.py @@ -0,0 +1,41 @@ +from databricks.sdk import WorkspaceClient + + +def check_if_endpoint_exists(client, endpoint_name): + try: + client.serving_endpoints.get(endpoint_name) + return True + except Exception as e: + if "RESOURCE_DOES_NOT_EXIST" in str(e): + return False + # raise e Don't raise the exception, just return False so we can create the endpoint + + +def get_latest_entity_version(client, endpoint_name): + serving_endpoint = client.serving_endpoints.get(endpoint_name) + latest_version = 1 + for entity in serving_endpoint.config.served_entities: + version_int = int(entity.entity_version) + if version_int > latest_version: + latest_version = version_int + return latest_version + + +def get_model_entity_index(client, endpoint_name, full_model_name) -> int: + serving_endpoint = client.serving_endpoints.get(endpoint_name) + model_index : int = -1 # endpoint only serving one model + for idx, entity in enumerate(serving_endpoint.config.served_entities): + if entity.entity_name == full_model_name: + model_index = int(idx) + break + return model_index + + +def get_env_config_file(env:str) -> str: + if env == "dev": + return "project_config_dev.yml" + elif env == "prod": + return "project_config_prod.yml" + else: + raise ValueError(f"Invalid environment: {env}") +