diff --git a/.gitignore b/.gitignore index 6f423de..c1c31ff 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,8 @@ __pycache__/ *.so # Folders -# data/ +mlruns/ +#data/ # Distribution / packaging .Python diff --git a/databricks.yml b/databricks.yml index b6780ba..0a1face 100644 --- a/databricks.yml +++ b/databricks.yml @@ -4,16 +4,132 @@ bundle: name: marvelous-databricks-course-a0134m + #databricks_cli_version: "0.230.0" + cluster_id: 1018-024853-shr990fc + +permissions: + - level: CAN_MANAGE + user_name: mahajan134@gmail.com + +artifacts: + default: + type: whl + build: uv build --wheel + # build: python -m build + path: . + +variables: + root_path: + description: root_path for the target + # default: /Shared/.bundle/${bundle.target}/${bundle.name} + default: /Workspace/Users/mahajan134@gmail.com/.bundle/marvelous-databricks-course-a0134m/dev/files + git_sha: + description: git_sha + default: abcd + schedule_pause_status: + description: schedule pause status + default: UNPAUSED targets: + prod: + workspace: + host: https://dbc-643c4c2b-d6c9.cloud.databricks.com + root_path: ${var.root_path} + dev: mode: development default: true workspace: host: https://dbc-643c4c2b-d6c9.cloud.databricks.com + # root_path: /Workspace/Users/mahajan134@gmail.com/.bundle/${bundle.name}/dev + +resources: + jobs: + wine-quality: + name: wine-quality-workflow + schedule: + quartz_cron_expression: "0 0 10 ? * MONDAY *" + timezone_id: "America/Chicago" + pause_status: ${var.schedule_pause_status} + tags: + project_name: "wine-quality" + job_clusters: + - job_cluster_key: "wine-quality-cluster" + new_cluster: + spark_version: "15.4.x-scala2.12" + data_security_mode: "SINGLE_USER" + node_type_id: "r3.xlarge" + driver_node_type_id: "r3.xlarge" + autoscale: + min_workers: 1 + max_workers: 1 - ## Optionally, there could be 'staging' or 'prod' targets here. - # - # prod: - # workspace: - # host: https://dbc-643c4c2b-d6c9.cloud.databricks.com + tasks: + - task_key: "preprocessing" + job_cluster_key: "wine-quality-cluster" + spark_python_task: + python_file: "week5/preprocess.py" + parameters: + - "--root_path" + - ${var.root_path} + libraries: + - whl: ./dist/*.whl + - task_key: if_refreshed + condition_task: + op: "EQUAL_TO" + left: "{{tasks.preprocessing.values.refreshed}}" + right: "1" + depends_on: + - task_key: "preprocessing" + - task_key: "train_model" + depends_on: + - task_key: "if_refreshed" + outcome: "true" + job_cluster_key: "wine-quality-cluster" + spark_python_task: + python_file: "week5/train_model.py" + parameters: + - "--root_path" + - ${var.root_path} + - "--git_sha" + - ${var.git_sha} + - "--job_run_id" + - "{{job.id}}" + libraries: + - whl: ./dist/*.whl + - task_key: "evaluate_model" + depends_on: + - task_key: "train_model" + job_cluster_key: "wine-quality-cluster" + spark_python_task: + python_file: "week5/evaluate_model.py" + parameters: + - "--root_path" + - ${var.root_path} + - "--new_model_uri" + - "{{tasks.train_model.values.new_model_uri}}" + - "--job_run_id" + - "{{job.id}}" + - "--git_sha" + - ${var.git_sha} + libraries: + - whl: ./dist/*.whl + - task_key: model_update + condition_task: + op: "EQUAL_TO" + left: "{{tasks.evaluate_model.values.model_update}}" + right: "1" + depends_on: + - task_key: "evaluate_model" + - task_key: "deploy_model" + depends_on: + - task_key: "model_update" + outcome: "true" + job_cluster_key: "wine-quality-cluster" + spark_python_task: + python_file: "week5/deploy_model.py" + parameters: + - "--root_path" + - ${var.root_path} + libraries: + - whl: ./dist/*.whl diff --git a/notebooks/week2/05.log_and_register_fe_model.py b/notebooks/week2/05.log_and_register_fe_model.py index 22b63d1..ac5b4e1 100644 --- a/notebooks/week2/05.log_and_register_fe_model.py +++ b/notebooks/week2/05.log_and_register_fe_model.py @@ -60,7 +60,7 @@ # COMMAND ---------- -# Create or replace the house_features table +# Create or replace the wine_features table spark.sql(f""" CREATE OR REPLACE TABLE {catalog_name}.{schema_name}.wine_features (Id STRING NOT NULL, diff --git a/notebooks/week5/00.create_source_data.py b/notebooks/week5/00.create_source_data.py new file mode 100644 index 0000000..17bc1db --- /dev/null +++ b/notebooks/week5/00.create_source_data.py @@ -0,0 +1,77 @@ +import pandas as pd +import numpy as np +from pyspark.sql import SparkSession +from pyspark.sql.functions import current_timestamp, to_utc_timestamp +from wine_quality.config import ProjectConfig + +# Load configuration +config = ProjectConfig.from_yaml(config_path="../../project_config.yml") +catalog_name = config.catalog_name +schema_name = config.schema_name +spark = SparkSession.builder.getOrCreate() + +# Load train and test sets +train_set = spark.table(f"{catalog_name}.{schema_name}.train_set").toPandas() +test_set = spark.table(f"{catalog_name}.{schema_name}.test_set").toPandas() +combined_set = pd.concat([train_set, test_set], ignore_index=True) +existing_ids = set(int(id) for id in combined_set['id']) + +# Define function to create synthetic data without random state +def create_synthetic_data(df, num_rows=100): + synthetic_data = pd.DataFrame() + + for column in df.columns: + # Treat float and int differently + # if pd.api.types.is_numeric_dtype(df[column]) and column != 'id': + # mean, std = df[column].mean(), df[column].std() + # synthetic_data[column] = np.random.normal(mean, std, num_rows) + if pd.api.types.is_float_dtype(df[column]): + mean, std = df[column].mean(), df[column].std() + synthetic_data[column] = np.random.normal(mean, std, num_rows) + elif pd.api.types.is_integer_dtype(df[column]) and column != 'id': + mean, std = df[column].mean(), df[column].std() + synthetic_data[column] = np.random.normal(mean, std, num_rows).astype(int) + elif pd.api.types.is_categorical_dtype(df[column]) or pd.api.types.is_object_dtype(df[column]): + synthetic_data[column] = np.random.choice(df[column].unique(), num_rows, + p=df[column].value_counts(normalize=True)) + + elif pd.api.types.is_datetime64_any_dtype(df[column]): + min_date, max_date = df[column].min(), df[column].max() + if min_date < max_date: + synthetic_data[column] = pd.to_datetime( + np.random.randint(min_date.value, max_date.value, num_rows) + ) + else: + synthetic_data[column] = [min_date] * num_rows + + else: + synthetic_data[column] = np.random.choice(df[column], num_rows) + + # Making sure that generated IDs are unique and do not previously exist + new_ids = [] + i = max(existing_ids) + 1 if existing_ids else 1 + while len(new_ids) < num_rows: + if i not in existing_ids: + new_ids.append(i) # Id needs to be string, but leaving it as int to match train/test set. Will convert to string later. + #new_ids.append(str(i)) # Convert numeric ID to string + i += 1 + synthetic_data['id'] = new_ids + + return synthetic_data + +# Create synthetic data +synthetic_df = create_synthetic_data(combined_set) + +# Create source_data table manually using Create table like train_set +existing_schema = spark.table(f"{catalog_name}.{schema_name}.source_data").schema + +synthetic_spark_df = spark.createDataFrame(synthetic_df, schema=existing_schema) + +train_set_with_timestamp = synthetic_spark_df.withColumn( + "update_timestamp_utc", to_utc_timestamp(current_timestamp(), "UTC") +) + +# Append synthetic data as new data to source_data table +train_set_with_timestamp.write.mode("append").saveAsTable( + f"{catalog_name}.{schema_name}.source_data" +) \ No newline at end of file diff --git a/project_config.yml b/project_config.yml index 56383a2..03f6d7c 100644 --- a/project_config.yml +++ b/project_config.yml @@ -1,5 +1,6 @@ catalog_name: mlops_students schema_name: mahajan134 +pipeline_id: 5d52b992-0c14-4abe-a8d8-3880ad86fe93 parameters: learning_rate: 0.01 diff --git a/src/wine_quality/config.py b/src/wine_quality/config.py index 90edc93..136ed40 100644 --- a/src/wine_quality/config.py +++ b/src/wine_quality/config.py @@ -11,6 +11,7 @@ class ProjectConfig(BaseModel): schema_name: str parameters: Dict[str, Any] # Dictionary to hold model-related parameters ab_test: Dict[str, Any] # Dictionary to hold A/B test parameters + pipeline_id: str # pipeline id for data live tables @classmethod def from_yaml(cls, config_path: str): diff --git a/week5/deploy_model.py b/week5/deploy_model.py new file mode 100644 index 0000000..3454247 --- /dev/null +++ b/week5/deploy_model.py @@ -0,0 +1,58 @@ +""" +This script handles the deployment of a wine quality prediction model to a Databricks serving endpoint. +Key functionality: +- Loads project configuration from YAML +- Retrieves the model version from previous task values +- Updates the serving endpoint configuration with: + - Model registry reference + - Scale to zero capability + - Workload sizing + - Specific model version +The endpoint is configured for feature-engineered model serving with automatic scaling. +""" + +import yaml +import argparse +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.serving import ServedEntityInput +from pyspark.dbutils import DBUtils +from pyspark.sql import SparkSession +from wine_quality.config import ProjectConfig + +parser = argparse.ArgumentParser() +parser.add_argument( + "--root_path", + action="store", + default=None, + type=str, + required=True, +) + +args = parser.parse_args() +root_path = args.root_path + +config_path = (f"{root_path}/project_config.yml") +# config_path = ("/Volumes/mlops_students/mahajan134/mlops_vol/project_config.yml") +config = ProjectConfig.from_yaml(config_path=config_path) + +spark = SparkSession.builder.getOrCreate() +dbutils = DBUtils(spark) + +model_version = dbutils.jobs.taskValues.get(taskKey="evaluate_model", key="model_version") + +workspace = WorkspaceClient() + +catalog_name = config.catalog_name +schema_name = config.schema_name + +workspace.serving_endpoints.update_config_and_wait( + name="wine-quality-model-serving-fe", + served_entities=[ + ServedEntityInput( + entity_name=f"{catalog_name}.{schema_name}.wine-quality-model-fe", + scale_to_zero_enabled=True, + workload_size="Small", + entity_version=model_version, + ) + ], +) \ No newline at end of file diff --git a/week5/evaluate_model.py b/week5/evaluate_model.py new file mode 100644 index 0000000..e1b719a --- /dev/null +++ b/week5/evaluate_model.py @@ -0,0 +1,181 @@ +""" +This script evaluates and compares a new house price prediction model against the currently deployed model. +Key functionality: +- Loads test data and performs feature engineering +- Generates predictions using both new and existing models +- Calculates and compares performance metrics (MAE and RMSE) +- Registers the new model if it performs better +- Sets task values for downstream pipeline steps + +The evaluation process: +1. Loads models from the serving endpoint +2. Prepares test data with feature engineering +3. Generates predictions from both models +4. Calculates error metrics +5. Makes registration decision based on MAE comparison +6. Updates pipeline task values with results +""" + +from databricks import feature_engineering +from databricks.sdk import WorkspaceClient +from pyspark.sql import SparkSession +from pyspark.dbutils import DBUtils + +from pyspark.sql import functions as F +from pyspark.sql.functions import col, lit, when +from pyspark.ml.evaluation import RegressionEvaluator +from datetime import datetime +import mlflow +import argparse +from pyspark.sql import functions as F +from pyspark.sql import DataFrame +from pyspark.ml.evaluation import RegressionEvaluator + +from wine_quality.config import ProjectConfig + + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--root_path", + action="store", + default=None, + type=str, + required=True, +) +parser.add_argument( + "--new_model_uri", + action="store", + default=None, + type=str, + required=True, +) + +parser.add_argument( + "--job_run_id", + action="store", + default=None, + type=str, + required=True, +) + +parser.add_argument( + "--git_sha", + action="store", + default=None, + type=str, + required=True, +) + + +args = parser.parse_args() +root_path = args.root_path +new_model_uri = args.new_model_uri +job_run_id = args.job_run_id +git_sha = args.git_sha + +config_path = (f"{root_path}/project_config.yml") +# config_path = ("/Volumes/mlops_test/house_prices/data/project_config.yml") +config = ProjectConfig.from_yaml(config_path=config_path) + +spark = SparkSession.builder.getOrCreate() +dbutils = DBUtils(spark) +workspace = WorkspaceClient() +fe = feature_engineering.FeatureEngineeringClient() + +mlflow.set_registry_uri("databricks-uc") +mlflow.set_tracking_uri("databricks") + +def get_latest_entity_version(serving_endpoint): + latest_version = 1 + for entity in serving_endpoint.config.served_entities: + version_int = int(entity.entity_version) + if version_int > latest_version: + latest_version = version_int + return latest_version + +# Extract configuration details +num_features = config.num_features +target = config.target +catalog_name = config.catalog_name +schema_name = config.schema_name + +# Define the serving endpoint +serving_endpoint_name = "wine-quality-model-serving-fe" +serving_endpoint = workspace.serving_endpoints.get(serving_endpoint_name) + +model_name = serving_endpoint.config.served_models[0].model_name +model_version = serving_endpoint.config.served_models[0].model_version +# latest_entity = get_latest_entity_version(serving_endpoint) +# print(latest_entity) +# model_name = serving_endpoint.config.served_entities[latest_entity].model_name +# model_version = serving_endpoint.config.served_entities[latest_entity].model_version + +previous_model_uri = f"models:/{model_name}/{model_version}" +cur_model_name = "wine-quality-model-fe" + + +# Load test set and create additional features in Spark DataFrame +test_set = spark.table(f"{catalog_name}.{schema_name}.test_set") +# Cast residual_sugar to int for the function inputs +test_set = test_set.withColumn("residual_sugar", test_set["residual_sugar"].cast("int")) +test_set = test_set.withColumn("is_sweet_indicator", when(col("residual_sugar") > 40, lit('1')).otherwise(lit('0'))) +test_set = test_set.withColumn("is_sweet_indicator", test_set["residual_sugar"].cast("int")) +test_set = test_set.withColumn("Id", test_set["id"].cast("string")) + + +# Select the necessary columns for prediction and target +X_test_spark = test_set.select(num_features + ["is_sweet_indicator", "Id"]) +y_test_spark = test_set.select("Id", target) + + +# Generate predictions from both models +predictions_previous = fe.score_batch(model_uri=previous_model_uri, df=X_test_spark) +predictions_new = fe.score_batch(model_uri=new_model_uri, df=X_test_spark) + +predictions_new = predictions_new.withColumnRenamed("prediction", "prediction_new") +predictions_old = predictions_previous.withColumnRenamed("prediction", "prediction_old") +test_set = test_set.select("Id", "quality") + +# Join the DataFrames on the 'id' column +df = test_set \ + .join(predictions_new, on="Id") \ + .join(predictions_old, on="Id") + +# Calculate the absolute error for each model +df = df.withColumn("error_new", F.abs(df["quality"] - df["prediction_new"])) +df = df.withColumn("error_old", F.abs(df["quality"] - df["prediction_old"])) + +# Calculate the absolute error for each model +df = df.withColumn("error_new", F.abs(df["quality"] - df["prediction_new"])) +df = df.withColumn("error_old", F.abs(df["quality"] - df["prediction_old"])) + +# Calculate the Mean Absolute Error (MAE) for each model +mae_new = df.agg(F.mean("error_new")).collect()[0][0] +mae_old = df.agg(F.mean("error_old")).collect()[0][0] + +# Calculate the Root Mean Squared Error (RMSE) for each model +evaluator = RegressionEvaluator(labelCol="quality", predictionCol="prediction_new", metricName="rmse") +rmse_new = evaluator.evaluate(df) + +evaluator.setPredictionCol("prediction_old") +rmse_old = evaluator.evaluate(df) + +# Compare models based on MAE and RMSE +print(f"MAE for New Model: {mae_new}") +print(f"MAE for Old Model: {mae_old}") + +if mae_new < mae_old: + print("New model is better based on MAE.") + model_version = mlflow.register_model( + model_uri=new_model_uri, + name=f"{catalog_name}.{schema_name}.{cur_model_name}", + tags={"git_sha": f"{git_sha}", + "job_run_id": job_run_id}) + + print("New model registered with version:", model_version.version) + dbutils.jobs.taskValues.set(key="model_version", value=model_version.version) + dbutils.jobs.taskValues.set(key="model_update", value=1) +else: + print("Old model is better based on MAE.") + dbutils.jobs.taskValues.set(key="model_update", value=0) \ No newline at end of file diff --git a/week5/preprocess.py b/week5/preprocess.py new file mode 100644 index 0000000..83712ef --- /dev/null +++ b/week5/preprocess.py @@ -0,0 +1,127 @@ +""" +This script handles data ingestion and feature table updates for a house price prediction system. + +Key functionality: +- Loads the source dataset and identifies new records for processing +- Splits new records into train and test sets based on timestamp +- Updates existing train and test tables with new data +- Inserts the latest feature values into the feature table for serving +- Triggers and monitors pipeline updates for online feature refresh +- Sets task values to coordinate pipeline orchestration + +Workflow: +1. Load source dataset and retrieve recent records with updated timestamps. +2. Split new records into train and test sets (80-20 split). +3. Append new train and test records to existing train and test tables. +4. Insert the latest feature data into the feature table for online serving. +5. Trigger a pipeline update and monitor its status until completion. +6. Set a task value indicating whether new data was processed. +""" + +import yaml +import argparse +from pyspark.dbutils import DBUtils +from pyspark.sql import SparkSession +from pyspark.sql.functions import col, max as spark_max +import yaml +from databricks.sdk import WorkspaceClient +import time +from wine_quality.config import ProjectConfig + + +workspace = WorkspaceClient() + +parser = argparse.ArgumentParser() +parser.add_argument( + "--root_path", + action="store", + default=None, + type=str, + required=True, +) + +args = parser.parse_args() +root_path = args.root_path +print(f"Root path: {root_path}") +config_path = (f"{root_path}/project_config.yml") +config = ProjectConfig.from_yaml(config_path=config_path) +pipeline_id = config.pipeline_id + +spark = SparkSession.builder.getOrCreate() +dbutils = DBUtils(spark) + +catalog_name = config.catalog_name +schema_name = config.schema_name + +# Load source_data table - mimicking arrival of new data +""" Source data is synthetic data created using test and train set tables which are already processed so source data is also processed. +If source data was 'raw' data then we should run the data_preprocessor class to clean the raw data before appending to test train tables.""" +source_data = spark.table(f"{catalog_name}.{schema_name}.source_data") + +# Get max update timestamps from existing data +max_train_timestamp = spark.table(f"{catalog_name}.{schema_name}.train_set") \ + .select(spark_max("update_timestamp_utc").alias("max_update_timestamp")) \ + .collect()[0]["max_update_timestamp"] + +max_test_timestamp = spark.table(f"{catalog_name}.{schema_name}.test_set") \ + .select(spark_max("update_timestamp_utc").alias("max_update_timestamp")) \ + .collect()[0]["max_update_timestamp"] + +latest_timestamp = max(max_train_timestamp, max_test_timestamp) + +# Filter source_data for rows with update_timestamp_utc greater than the latest_timestamp +new_data = source_data.filter(col("update_timestamp_utc") > latest_timestamp) + +# Split the new data into train and test sets +new_data_train, new_data_test = new_data.randomSplit([0.8, 0.2], seed=42) + +# Update train_set and test_set tables +new_data_train.write.mode("append").saveAsTable(f"{catalog_name}.{schema_name}.train_set") +new_data_test.write.mode("append").saveAsTable(f"{catalog_name}.{schema_name}.test_set") + +# Verify affected rows count for train and test +affected_rows_train = new_data_train.count() +affected_rows_test = new_data_test.count() + +#write into feature table; update online table +if affected_rows_train > 0 or affected_rows_test > 0 : + spark.sql(f""" + WITH max_timestamp AS ( + SELECT MAX(update_timestamp_utc) AS max_update_timestamp + FROM {catalog_name}.{schema_name}.train_set + ) + INSERT INTO {catalog_name}.{schema_name}.wine_features + SELECT cast(id as string) as Id, volatile_acidity, alcohol, sulphates + FROM {catalog_name}.{schema_name}.train_set + WHERE update_timestamp_utc == (SELECT max_update_timestamp FROM max_timestamp) +""") + spark.sql(f""" + WITH max_timestamp AS ( + SELECT MAX(update_timestamp_utc) AS max_update_timestamp + FROM {catalog_name}.{schema_name}.test_set + ) + INSERT INTO {catalog_name}.{schema_name}.wine_features + SELECT cast(id as string) as Id, volatile_acidity, alcohol, sulphates + FROM {catalog_name}.{schema_name}.test_set + WHERE update_timestamp_utc == (SELECT max_update_timestamp FROM max_timestamp) +""") + refreshed = 1 + update_response = workspace.pipelines.start_update( + pipeline_id=pipeline_id, full_refresh=False) + while True: + update_info = workspace.pipelines.get_update(pipeline_id=pipeline_id, + update_id=update_response.update_id) + state = update_info.update.state.value + if state == 'COMPLETED': + break + elif state in ['FAILED', 'CANCELED']: + raise SystemError("Online table failed to update.") + elif state == 'WAITING_FOR_RESOURCES': + print("Pipeline is waiting for resources.") + else: + print(f"Pipeline is in {state} state.") + time.sleep(30) +else: + refreshed = 0 + +dbutils.jobs.taskValues.set(key="refreshed", value=refreshed) \ No newline at end of file diff --git a/week5/train_model.py b/week5/train_model.py new file mode 100644 index 0000000..c979c50 --- /dev/null +++ b/week5/train_model.py @@ -0,0 +1,177 @@ +""" +This script trains a LightGBM model for house price prediction with feature engineering. +Key functionality: +- Loads training and test data from Databricks tables +- Performs feature engineering using Databricks Feature Store +- Creates a pipeline with preprocessing and LightGBM regressor +- Tracks the experiment using MLflow +- Logs model metrics, parameters and artifacts +- Handles feature lookups and custom feature functions +- Outputs model URI for downstream tasks + +The model uses numerical features, including a custom calculated wine sweetness feature. +""" + +from databricks import feature_engineering +from pyspark.dbutils import DBUtils +from pyspark.sql import SparkSession +from databricks.sdk import WorkspaceClient +import mlflow +import argparse +from datetime import datetime +from lightgbm import LGBMRegressor +from mlflow.models import infer_signature +from pyspark.sql import functions as F +from sklearn.compose import ColumnTransformer +from sklearn.impute import SimpleImputer +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from databricks.feature_engineering import FeatureFunction, FeatureLookup +from wine_quality.config import ProjectConfig + + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--root_path", + action="store", + default=None, + type=str, + required=True, +) +parser.add_argument( + "--git_sha", + action="store", + default=None, + type=str, + required=True, +) +parser.add_argument( + "--job_run_id", + action="store", + default=None, + type=str, + required=True, +) + +args = parser.parse_args() +root_path = args.root_path +git_sha = args.git_sha +job_run_id = args.job_run_id +cur_experiment_name = "/Shared/wine-quality-fe" +cur_model_artifact_path = "wine-quality-model-fe" + + +config_path = (f"{root_path}/project_config.yml") +# config_path = ("/Volumes/mlops_students/mahajan134/mlops_vol/project_config.yml") + +config = ProjectConfig.from_yaml(config_path=config_path) + +# Initialize the Databricks session and clients +spark = SparkSession.builder.getOrCreate() +dbutils = DBUtils(spark) +workspace = WorkspaceClient() +fe = feature_engineering.FeatureEngineeringClient() + +mlflow.set_registry_uri("databricks-uc") +mlflow.set_tracking_uri("databricks") + +# Extract configuration details +num_features = config.num_features +target = config.target +parameters = config.parameters +catalog_name = config.catalog_name +schema_name = config.schema_name + + +# Define table names and function name +feature_table_name = f"{catalog_name}.{schema_name}.wine_features" +function_name = f"{catalog_name}.{schema_name}.calculate_wine_sweetness" + +# Load training and test sets +train_set = spark.table(f"{catalog_name}.{schema_name}.train_set").drop("volatile_acidity", "alcohol", "sulphates") +test_set = spark.table(f"{catalog_name}.{schema_name}.test_set").toPandas() + +# Cast residual_sugar to int for the function input +train_set = train_set.withColumn("residual_sugar", train_set["residual_sugar"].cast("int")) +train_set = train_set.withColumn("Id", train_set["id"].cast("string")) + + +# Feature engineering setup +training_set = fe.create_training_set( + df=train_set, + label=target, + feature_lookups=[ + FeatureLookup( + table_name=feature_table_name, + feature_names=["volatile_acidity", "alcohol", "sulphates"], + lookup_key="Id", + ), + FeatureFunction( + udf_name=function_name, + output_name="is_sweet_indicator", + input_bindings={"residual_sugar_content": "residual_sugar"}, + ), + ], + exclude_columns=["update_timestamp_utc"] +) + +# Load feature-engineered DataFrame +training_df = training_set.load_df().toPandas() + +# Calculate is_sweet_indicator for training and test set +test_set["is_sweet_indicator"] = test_set["residual_sugar"].apply(lambda x: 1 if x > 40 else 0) + +# Split features and target +X_train = training_df[num_features + ["is_sweet_indicator"]] +y_train = training_df[target] +X_test = test_set[num_features + ["is_sweet_indicator"]] +y_test = test_set[target] + +# Setup preprocessing and model pipeline +preprocessor = ColumnTransformer( + transformers=[("num", SimpleImputer(strategy="mean"), num_features), ("std", StandardScaler(), num_features)], + remainder="passthrough", +) + +# Create the pipeline with preprocessing and the LightGBM regressor +pipeline = Pipeline(steps=[("preprocessor", preprocessor), ("regressor", LGBMRegressor(**parameters))]) + + +mlflow.set_experiment(experiment_name=cur_experiment_name) + +with mlflow.start_run(tags={"branch": "week5", + "git_sha": f"{git_sha}", + "job_run_id": job_run_id}) as run: + run_id = run.info.run_id + pipeline.fit(X_train, y_train) + y_pred = pipeline.predict(X_test) + + # Calculate and print metrics + mse = mean_squared_error(y_test, y_pred) + mae = mean_absolute_error(y_test, y_pred) + r2 = r2_score(y_test, y_pred) + print(f"Mean Squared Error: {mse}") + print(f"Mean Absolute Error: {mae}") + print(f"R2 Score: {r2}") + + # Log model parameters, metrics, and model + mlflow.log_param("model_type", "LightGBM with preprocessing") + mlflow.log_params(parameters) + mlflow.log_metric("mse", mse) + mlflow.log_metric("mae", mae) + mlflow.log_metric("r2_score", r2) + signature = infer_signature(model_input=X_train, model_output=y_pred) + + # Log model with feature engineering + fe.log_model( + model=pipeline, + flavor=mlflow.sklearn, + artifact_path=cur_model_artifact_path, + training_set=training_set, + signature=signature, + ) + +model_uri=f'runs:/{run_id}/{cur_model_artifact_path}' +dbutils.jobs.taskValues.set(key="new_model_uri", value=model_uri) \ No newline at end of file