Feature/week2 #2
Merged

Commits (22, all by Garett601):
- ea58f16  refactor(project_configs.yml-config.py): update project config
- a44a2eb  refactor(utils.py,-rf_model.py-data_preprocessor.py-processed_data.py…
- 0b2660d  test(test_rf_model.py-test_data_preprocessor.py): update tests for re…
- 4573671  refactor(week1/week_1.ipynb): update repo structure
- 06445d5  refactor(week_1.py): refactor week 1 notebook
- 507c239  refactor(week_1.py): refactor week 1 notebook
- b35f12f  feat(data_preprocessor.py): add save_to_catalog method
- 6a62cea  style(data_preprocessor.py): update docstring
- b2e8238  test(test_data_preprocessor.py): add test for save_to_catalog method
- 30d57af  feat(01_prepare_data.py): first task for week2 homework
- 1107246  feat(02_mlflow_experiment.py): week 2 notebook 2 code
- 444e285  feat(03_log_and_register_model.py): week2 notebook 3 code
- 65d3bf3  feat(04_log_and_register_custom_model.py): week2 notebook 4 code
- a2c9e31  chore: save latest code - why do I keep doing this?
- 65b5fe8  refactor(01_prepare_dataset.py): maintain DateTime column when saving…
- e7f7120  feat(05_log_and_register_fe_model.py): week2 notebook 5 code
- c612a55  docs(README.md): update readme with changes
- 7aeeb1d  ci(ci.yml): add PySpark to CI run for tests
- beccea3  ci(ci.yml): add databricks-sdk to testing step
- 6a8146a  ci(ci.yml): add pyspark AND databricks-sdk to tests step
- f25a887  test(test_data_preprocessor.py): mock pyspark.sql.functions
- ea29214  test(test_data_preprocessor.py): mock all spark behaviour
Files changed
@@ -98,3 +98,9 @@ dmypy.json
```
.databricks
.ruff_cache/

# JSON files
*.json

# MLFlow
mlruns/
```
Binary file not shown.
@@ -0,0 +1,68 @@
```python
# Databricks notebook source
from power_consumption.preprocessing.data_preprocessor import DataProcessor
from power_consumption.model.rf_model import ConsumptionModel
from power_consumption.utils import visualise_results, plot_actual_vs_predicted, plot_feature_importance
from power_consumption.config import Config
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# COMMAND ----------

config = Config.from_yaml("../../configs/project_configs.yml")

# COMMAND ----------

catalog_name = config.catalog_name
schema_name = config.schema_name
raw_data_table = config.dataset.raw_data_table

# COMMAND ----------

data_spark = spark.table(f"{catalog_name}.{schema_name}.{raw_data_table}")

# COMMAND ----------

data_pandas = data_spark.toPandas()

# COMMAND ----------

data_processor = DataProcessor(config, data_pandas)

# COMMAND ----------

data_processor.preprocess_data()

# COMMAND ----------

train_set, test_set = data_processor.split_data()

# COMMAND ----------

target_columns = config.target.target
feature_columns = config.processed_features.num_features + config.processed_features.cat_features

X_train = train_set[feature_columns]
y_train = train_set[target_columns]
X_test = test_set[feature_columns]
y_test = test_set[target_columns]

# COMMAND ----------

model = ConsumptionModel(config)
model.train(X_train, y_train)

# COMMAND ----------

# Make predictions and evaluate the model
y_pred = model.predict(X_test)
mse, r2 = model.evaluate(X_test, y_test)

# COMMAND ----------

# Visualise results as time series
visualise_results(y_test, y_pred, target_columns)

# COMMAND ----------

# Get feature importance
feature_importance, feature_names = model.get_feature_importance()

# COMMAND ----------

# Plot actual vs predicted values
plot_actual_vs_predicted(y_test.values, y_pred, target_columns)

# COMMAND ----------

# Plot feature importance
plot_feature_importance(feature_importance, feature_names, top_n=15)

# COMMAND ----------
```
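The notebooks in this PR all hydrate their settings through `Config.from_yaml("../../configs/project_configs.yml")` and then read nested attributes such as `config.dataset.raw_data_table`, `config.processed_features.num_features`, `config.target.target`, and `config.hyperparameters`. The diff for `config.py` is not part of this view, so the following is only a minimal sketch of a config object consistent with those accesses; the field names come from the notebook code, while the Pydantic-style structure and the example hyperparameter keys are assumptions.

```python
# Hypothetical sketch of power_consumption/config.py; the real file is not
# shown in this diff. Field names mirror the notebook accesses; the
# Pydantic structure is an assumption.
import yaml
from pydantic import BaseModel


class DatasetConfig(BaseModel):
    raw_data_table: str


class ProcessedFeatures(BaseModel):
    num_features: list[str]
    cat_features: list[str]


class TargetConfig(BaseModel):
    target: list[str]  # several targets: the model is multi-output


class Hyperparameters(BaseModel):
    # Illustrative keys only; the real ones live in project_configs.yml.
    n_estimators: int = 100
    learning_rate: float = 0.1


class Config(BaseModel):
    catalog_name: str
    schema_name: str
    dataset: DatasetConfig
    processed_features: ProcessedFeatures
    target: TargetConfig
    hyperparameters: Hyperparameters

    @classmethod
    def from_yaml(cls, path: str) -> "Config":
        # Parse the YAML file and validate it against the model above.
        with open(path) as f:
            return cls(**yaml.safe_load(f))
```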
@@ -0,0 +1,32 @@
```python
# Databricks notebook source
from power_consumption.preprocessing.data_preprocessor import DataProcessor
from power_consumption.config import Config
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# COMMAND ----------

config = Config.from_yaml("../../configs/project_configs.yml")

# COMMAND ----------

catalog_name = config.catalog_name
schema_name = config.schema_name
raw_data_table = config.dataset.raw_data_table

# COMMAND ----------

data_spark = spark.table(f"{catalog_name}.{schema_name}.{raw_data_table}")

# COMMAND ----------

data_pandas = data_spark.toPandas()

# COMMAND ----------

data_processor = DataProcessor(config, data_pandas)

# COMMAND ----------

data_processor.preprocess_data()

# COMMAND ----------

train_set, test_set = data_processor.split_data()

# COMMAND ----------

train_set.reset_index(inplace=True)
test_set.reset_index(inplace=True)

# COMMAND ----------

data_processor.save_to_catalog(train_set=train_set, test_set=test_set, spark=spark)

# COMMAND ----------
```
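`save_to_catalog` is the method added in commit b35f12f; its implementation in `data_preprocessor.py` is not shown in this view. Commit 65b5fe8 ("maintain DateTime column when saving…") suggests the splits keep their time index when written out, which is also why the notebook calls `reset_index` first. A plausible sketch, assuming the method converts each split to a Spark DataFrame, stamps it with an insertion timestamp, and appends it to a Unity Catalog table; the timestamp column name and the write mode are assumptions.

```python
# Hypothetical sketch of DataProcessor.save_to_catalog; the actual
# implementation lives in data_preprocessor.py and is not shown here.
import pandas as pd
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import current_timestamp


def save_to_catalog(self, train_set: pd.DataFrame, test_set: pd.DataFrame,
                    spark: SparkSession) -> None:
    """Write the train/test splits to Unity Catalog tables."""
    catalog_name = self.config.catalog_name
    schema_name = self.config.schema_name

    for table_name, split in (("train_set", train_set), ("test_set", test_set)):
        spark_df: DataFrame = spark.createDataFrame(split)
        # Assumed: record when each row was written.
        spark_df = spark_df.withColumn("update_timestamp_utc", current_timestamp())
        spark_df.write.mode("append").saveAsTable(
            f"{catalog_name}.{schema_name}.{table_name}"
        )
```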
@@ -0,0 +1,53 @@
```python
# Databricks notebook source
import json

import mlflow

mlflow.set_tracking_uri("databricks")

mlflow.set_experiment(experiment_name="/Shared/power-consumption")
mlflow.set_experiment_tags({"repository_name": "power-consumption"})

# COMMAND ----------

experiments = mlflow.search_experiments(
    filter_string="tags.repository_name='power-consumption'"
)

print(experiments)

# COMMAND ----------

with open("mlflow_experiment.json", "w") as json_file:
    json.dump(experiments[0].__dict__, json_file, indent=4)

# COMMAND ----------

with mlflow.start_run(
    run_name="test-run",
    tags={
        "git_sha": "30d57afb2efca70cede3061d00f2a553c2b4779b"
    },
) as run:
    mlflow.log_params({"type": "demo"})
    mlflow.log_metrics(
        {
            "metric_1": 1.0,
            "metric_2": 2.0,
        }
    )

# COMMAND ----------

run_id = mlflow.search_runs(
    experiment_names=["/Shared/power-consumption"],
    filter_string="tags.git_sha='30d57afb2efca70cede3061d00f2a553c2b4779b'",
).run_id[0]
run_info = mlflow.get_run(run_id=f"{run_id}").to_dictionary()
print(run_info)

# COMMAND ----------

with open("run_info.json", "w") as json_file:
    json.dump(run_info, json_file, indent=4)

# COMMAND ----------

print(run_info["data"]["metrics"])

# COMMAND ----------

print(run_info["data"]["params"])

# COMMAND ----------
```
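One caveat in the notebook above: `json.dump(experiments[0].__dict__, ...)` serialises the MLflow `Experiment` entity through its private attributes, so the JSON keys come out as `_experiment_id`, `_name`, and so on. A small sketch of an alternative that builds the payload from the entity's public properties instead; the property names are standard MLflow, everything else mirrors the notebook.

```python
# Alternative sketch: serialise the experiment via its public properties
# so the JSON keys are stable, public names rather than _private ones.
import json

experiment = experiments[0]
experiment_info = {
    "experiment_id": experiment.experiment_id,
    "name": experiment.name,
    "artifact_location": experiment.artifact_location,
    "lifecycle_stage": experiment.lifecycle_stage,
    "tags": experiment.tags,
}

with open("mlflow_experiment.json", "w") as json_file:
    json.dump(experiment_info, json_file, indent=4)
```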
@@ -0,0 +1,104 @@
```python
# Databricks notebook source
import mlflow
from mlflow.models import infer_signature

from pyspark.sql import SparkSession
from power_consumption.config import Config

from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from lightgbm import LGBMRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")

# COMMAND ----------

config = Config.from_yaml("../../configs/project_configs.yml")

# COMMAND ----------

num_features = config.processed_features.num_features
cat_features = config.processed_features.cat_features
target = config.target.target
parameters = config.hyperparameters.__dict__

catalog_name = config.catalog_name
schema_name = config.schema_name

# COMMAND ----------

spark = SparkSession.builder.getOrCreate()

train_set_spark = spark.table(f"{catalog_name}.{schema_name}.train_set")
train_set = spark.table(f"{catalog_name}.{schema_name}.train_set").toPandas()
test_set = spark.table(f"{catalog_name}.{schema_name}.test_set").toPandas()

# COMMAND ----------

X_train = train_set[num_features + cat_features]
y_train = train_set[target]

X_test = test_set[num_features + cat_features]
y_test = test_set[target]

# COMMAND ----------

# Define the preprocessor for categorical features
preprocessor = ColumnTransformer(
    transformers=[("cat", OneHotEncoder(handle_unknown="ignore"), cat_features)],
    remainder="passthrough",
)

# Create the pipeline with preprocessing and the multi-output LightGBM regressor
pipeline = Pipeline(steps=[
    ("preprocessor", preprocessor),
    ("regressor", MultiOutputRegressor(LGBMRegressor(**parameters))),
])

# COMMAND ----------

mlflow.set_experiment(experiment_name="/Shared/power-consumption")
git_sha = "30d57afb2efca70cede3061d00f2a553c2b4779b"

# Start an MLflow run to track the training process
with mlflow.start_run(
    tags={"git_sha": f"{git_sha}", "branch": "feature/week2"},
) as run:
    run_id = run.info.run_id

    pipeline.fit(X_train, y_train)
    y_pred = pipeline.predict(X_test)

    # Evaluate the model performance
    mse = mean_squared_error(y_test, y_pred)
    mae = mean_absolute_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)

    print(f"Mean Squared Error: {mse}")
    print(f"Mean Absolute Error: {mae}")
    print(f"R2 Score: {r2}")

    # Log parameters, metrics, and the model to MLflow
    mlflow.log_param("model_type", "LightGBM with preprocessing")
    mlflow.log_params(parameters)
    mlflow.log_metric("mse", mse)
    mlflow.log_metric("mae", mae)
    mlflow.log_metric("r2_score", r2)
    signature = infer_signature(model_input=X_train, model_output=y_pred)

    dataset = mlflow.data.from_spark(
        train_set_spark,
        table_name=f"{catalog_name}.{schema_name}.train_set",
        version="0",
    )
    mlflow.log_input(dataset, context="training")

    mlflow.sklearn.log_model(
        sk_model=pipeline,
        artifact_path="lightgbm-pipeline-model",
        signature=signature,
    )

# COMMAND ----------

model_version = mlflow.register_model(
    model_uri=f"runs:/{run_id}/lightgbm-pipeline-model",
    name=f"{catalog_name}.{schema_name}.power_consumption_model",
    tags={"git_sha": f"{git_sha}"},
)

# COMMAND ----------

run = mlflow.get_run(run_id)
dataset_info = run.inputs.dataset_inputs[0].dataset
dataset_source = mlflow.data.get_source(dataset_info)
dataset_source.load()

# COMMAND ----------
```
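The run above registers the fitted pipeline in Unity Catalog as `{catalog_name}.{schema_name}.power_consumption_model`. A short usage sketch for loading that registered version back and scoring with it; these are standard MLflow calls, and `model_version` is the object returned by `mlflow.register_model` in the notebook.

```python
# Sketch: load the version registered above from Unity Catalog and predict.
import mlflow

mlflow.set_registry_uri("databricks-uc")

model_uri = (
    f"models:/{catalog_name}.{schema_name}.power_consumption_model/"
    f"{model_version.version}"
)
loaded_pipeline = mlflow.sklearn.load_model(model_uri)
predictions = loaded_pipeline.predict(X_test)
```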