feat: train and register model/experiment #2
Conversation
Walkthrough
The changes introduce new functionalities for managing MLflow experiments and data processing within a Databricks environment.
Actionable comments posted: 11
🧹 Outside diff range and nitpick comments (7)
notebooks/mlflow_exp.py (1)
1-44: Consider refactoring into a proper Python module.
The current implementation as a Databricks notebook with multiple commands reduces maintainability and testability. Consider:
- Moving the code to a proper Python module
- Creating a class to encapsulate MLflow operations
- Adding proper logging
- Including unit tests
Example structure:
# src/mlops_with_databricks/mlflow/experiment.py
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Any

logger = logging.getLogger(__name__)


@dataclass
class MLflowConfig:
    workspace_url: str
    experiment_name: str
    repository_name: str


class MLflowExperiment:
    def __init__(self, config: MLflowConfig):
        self.config = config
        self._setup_mlflow()

    def _setup_mlflow(self) -> None:
        # Setup implementation
        pass

    def create_run(self, run_name: str, description: str) -> None:
        # Run creation implementation
        pass

    def get_run_info(self, git_sha: str) -> Dict[str, Any]:
        # Run retrieval implementation
        pass
src/mlops_with_databricks/data_preprocessing/dataclasses.py (2)
65-71: Consider moving Databricks configuration to environment variables.
Hardcoding environment-specific values like workspace URL and schema details directly in the code makes it difficult to manage different environments (dev, staging, prod) and poses security risks.
Consider:
- Moving these values to environment variables or a configuration file
- Using a configuration management system
- Implementing a configuration loader
Example implementation:
import os
from dataclasses import dataclass


@dataclass
class DatabricksConfig:
    """Dataclass for the Databricks configuration."""

    workspace_url: str = os.getenv('DATABRICKS_WORKSPACE_URL')
    catalog_name: str = os.getenv('DATABRICKS_CATALOG_NAME', 'mlops_students')
    schema_name: str = os.getenv('DATABRICKS_SCHEMA_NAME')
80-80: Document hyperparameter choices and add validation.
While the hyperparameter values seem reasonable, it would be helpful to:
- Document why these specific values were chosen
- Add validation to ensure parameters are within acceptable ranges
Consider adding a validation function:
def validate_lightgbm_config(config: LightGBMConfig) -> None:
    """Validate LightGBM hyperparameters."""
    if not (0 < config['learning_rate'] <= 1):
        raise ValueError("learning_rate must be between 0 and 1")
    if config['n_estimators'] < 1:
        raise ValueError("n_estimators must be positive")
    if config['max_depth'] < 1:
        raise ValueError("max_depth must be positive")
src/mlops_with_databricks/training/train.py (2)
39-39: Remove unnecessary debug print statement
The print(train_set.head()) statement is likely used for debugging purposes. It can be removed to clean up the console output.
Apply this diff to remove the print statement:
- print(train_set.head())
67-67: Retrieve Git SHA dynamically instead of hardcoding
Hardcoding the git_sha might lead to inconsistencies if the code changes and the SHA is not updated. Retrieve the Git SHA programmatically to ensure it always reflects the current commit.
Apply this diff to retrieve the Git SHA dynamically:
+ import subprocess
- git_sha = "ffa63b430205ff7"
+ git_sha = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).strip().decode()
Note: Ensure that the execution environment has access to the Git repository and the git command-line tool.
src/mlops_with_databricks/data_preprocessing/preprocess.py (2)
35-37: Remove redundant assignments in the from_pandas method
The attributes X, y, and preprocessor are already initialized to None in the __init__ method. Reassigning them to None in from_pandas is redundant.
Apply this diff to remove the redundant assignments:
instance = cls()
-instance.X = None
-instance.y = None
-instance.preprocessor = None
instance.df = pandas_df
136-140: Avoid hardcoding configurations and file paths
The profile name and the file path are hardcoded, which can lead to issues when running the code in different environments. Consider using configuration files, environment variables, or parameters to make the code more flexible and maintainable.
Apply this diff to parameterize the profile and file path:
-spark = DatabricksSession.builder.profile("dbc-643c4c2b-d6c9").getOrCreate()
+import os
+
+profile_name = os.getenv("DATABRICKS_PROFILE", "default_profile") # Replace with your default profile
+spark = DatabricksSession.builder.profile(profile_name).getOrCreate()
-df = spark.read.csv(
- "/Volumes/mlops_students/armak58/data/ad_click_dataset.csv", header=True, inferSchema=True
-).toPandas()
+data_path = os.getenv("DATA_PATH", "/path/to/ad_click_dataset.csv") # Replace with your default data path
+df = spark.read.csv(data_path, header=True, inferSchema=True).toPandas()
This allows the profile name and data path to be configurable through environment variables, improving the portability of the code.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)
- notebooks/mlflow_exp.py (1 hunks)
- src/mlops_with_databricks/data_preprocessing/dataclasses.py (2 hunks)
- src/mlops_with_databricks/data_preprocessing/preprocess.py (3 hunks)
- src/mlops_with_databricks/training/train.py (1 hunks)
🔇 Additional comments (2)
src/mlops_with_databricks/data_preprocessing/dataclasses.py (1)
37-62: LGTM! Well-structured feature configuration.
The ProcessedAdClickDataConfig class provides a clear and comprehensive configuration for processed features, with proper one-hot encoding representation for categorical variables.
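For readers without the diff open, a configuration of this general shape might look roughly like the sketch below. This is a hypothetical illustration only; the actual field names live in dataclasses.py, and the gender/device columns here are invented examples.

from dataclasses import dataclass


@dataclass
class ProcessedAdClickDataConfig:
    """Hypothetical sketch: column names produced after preprocessing/one-hot encoding."""

    age: str = "age"
    # One-hot encoded categorical columns (illustrative names only)
    gender_male: str = "gender_Male"
    device_type_mobile: str = "device_type_Mobile"
    click: str = "click"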
src/mlops_with_databricks/training/train.py (1)
95-95: Log the best pipeline from GridSearchCV
You're currently logging the original pipeline, which doesn't include the best parameters found during hyperparameter tuning. Log best_pipeline to ensure the model artifact uses the optimized parameters.
Apply this diff to log the best pipeline:
- mlflow.sklearn.log_model(sk_model=pipeline, artifact_path="lightgbm-pipeline-model", signature=signature)
+ mlflow.sklearn.log_model(sk_model=best_pipeline, artifact_path="lightgbm-pipeline-model", signature=signature)
Likely invalid or redundant comment.
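For context, a minimal sketch of where best_pipeline would come from, assuming the standard scikit-learn GridSearchCV API (variable names here mirror the review but are illustrative):

# GridSearchCV refits the best parameter combination on the full training data
# (refit=True by default) and exposes it as best_estimator_.
grid_search.fit(X_train, y_train)
best_pipeline = grid_search.best_estimator_

mlflow.sklearn.log_model(
    sk_model=best_pipeline,
    artifact_path="lightgbm-pipeline-model",
    signature=signature,
)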
Actionable comments posted: 8
🧹 Outside diff range and nitpick comments (4)
src/mlops_with_databricks/data_preprocessing/dataclasses.py (2)
52-58: Consider externalizing Databricks configuration.
Hardcoding Databricks workspace URL and schema details directly in the code makes it less portable and harder to maintain. Consider:
- Moving these values to environment variables or a configuration file
- Making the schema name configurable to support different users
Example approach using environment variables:
import os
from dataclasses import dataclass


@dataclass
class DatabricksConfig:
    """Dataclass for the Databricks configuration."""

    workspace_url: str = os.getenv('DATABRICKS_WORKSPACE_URL')
    catalog_name: str = os.getenv('DATABRICKS_CATALOG_NAME')
    schema_name: str = os.getenv('DATABRICKS_SCHEMA_NAME')
67-67: Document hyperparameter choices and consider validation.
While the hyperparameter values look reasonable, it would be helpful to:
- Document why these specific values were chosen
- Verify if these parameters are optimal for your dataset size and characteristics
Add a comment explaining the parameter choices:
# Default hyperparameters based on initial experiments:
# - learning_rate: Conservative value to ensure stable training
# - n_estimators: Balanced between model performance and training time
# - max_depth: Prevents overfitting while allowing sufficient model complexity
light_gbm_config = LightGBMConfig(learning_rate=0.001, n_estimators=200, max_depth=10)
src/mlops_with_databricks/training/train.py (1)
118-121: Remove or document commented-out code
Either remove the commented-out code or add a comment explaining why it's kept and when it might be needed.
src/mlops_with_databricks/data_preprocessing/preprocess.py (1)
51-51: Clarify the commented-out OneHotEncoder.
The commented-out OneHotEncoder suggests incomplete changes. If it's no longer needed, remove it. If it's temporarily disabled, add a TODO comment explaining why.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- src/mlops_with_databricks/data_preprocessing/dataclasses.py (2 hunks)
- src/mlops_with_databricks/data_preprocessing/preprocess.py (3 hunks)
- src/mlops_with_databricks/training/train.py (1 hunks)
🔇 Additional comments (6)
src/mlops_with_databricks/data_preprocessing/dataclasses.py (1)
61-64: LGTM! TypedDict implementation looks correct.
The implementation properly defines types for the LightGBM parameters without default values, addressing the previous review comments.
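A minimal sketch of what such a TypedDict could look like, assuming the hyperparameter names used elsewhere in this review (not the repository's exact definition):

from typing import TypedDict


class LightGBMConfig(TypedDict):
    """Typed LightGBM hyperparameters; values are supplied at construction time."""

    learning_rate: float
    n_estimators: int
    max_depth: int


# Construction returns a plain dict whose keys and value types are checked statically.
light_gbm_config = LightGBMConfig(learning_rate=0.001, n_estimators=200, max_depth=10)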
src/mlops_with_databricks/training/train.py (2)
22-28: LGTM: Well-structured git info retrieval with proper error handling
The function properly handles potential git command failures and provides fallback values.
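As a rough illustration of the pattern being praised, a sketch under the assumption that the function shells out to git and falls back to placeholder values (the repository's version may differ):

import subprocess


def get_git_info() -> tuple[str, str]:
    """Return (git_sha, branch), falling back to placeholders when git is unavailable."""
    try:
        sha = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).strip().decode()
        branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).strip().decode()
        return sha, branch
    except (subprocess.CalledProcessError, FileNotFoundError):
        return "unknown", "unknown"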
75-75: Reconsider using n_jobs=-1 in Databricks environment
Using n_jobs=-1 might not be optimal in a Databricks environment as it can interfere with the cluster's resource management.
Consider setting this to a specific number based on the cluster configuration or making it configurable:
-grid_search = GridSearchCV(pipeline, param_grid, cv=4, scoring="f1", n_jobs=-1)
+grid_search = GridSearchCV(pipeline, param_grid, cv=4, scoring="f1", n_jobs=DatabricksConfig.grid_search_jobs)
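One way to make the parallelism configurable, sketched here with a hypothetical DATABRICKS_GRID_SEARCH_JOBS environment variable rather than a real config field from the repository:

import os

from sklearn.model_selection import GridSearchCV

# Default to a small, fixed worker count instead of claiming every core on a shared cluster.
grid_search_jobs = int(os.getenv("DATABRICKS_GRID_SEARCH_JOBS", "2"))
grid_search = GridSearchCV(pipeline, param_grid, cv=4, scoring="f1", n_jobs=grid_search_jobs)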
src/mlops_with_databricks/data_preprocessing/preprocess.py (3)
6-8: LGTM: New imports are properly organized.
The added imports support Databricks integration and local configurations, following Python's import organization conventions.
Also applies to: 14-14
20-24: LGTM: Improved class initialization design.
The separation of concerns between initialization and data loading, along with the addition of the from_pandas factory method, enhances flexibility and follows good Python practices.
Also applies to: 30-35
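A stripped-down sketch of this factory-method pattern, based on the diff shown earlier in this review (the repository's class also carries a preprocessor and other state):

import pandas as pd


class DataProcessor:
    def __init__(self):
        # State is created empty; data is attached via a loader or the factory below.
        self.df = None
        self.X = None
        self.y = None
        self.preprocessor = None

    @classmethod
    def from_pandas(cls, pandas_df: pd.DataFrame) -> "DataProcessor":
        """Build a processor from an already-loaded pandas DataFrame."""
        instance = cls()
        instance.df = pandas_df
        return instance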
61-63: Verify target variable handling in preprocessing.
The current approach of adding the target variable back to preprocessed features might lead to data leakage. Consider keeping features and target separate throughout the pipeline.
I approve to get you unblocked. Please pay attention to the profile & SparkSession
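For the profile & SparkSession point, one pattern is to prefer Databricks Connect when it is installed and fall back to a plain SparkSession otherwise. This is a sketch, and the profile name is a placeholder:

try:
    from databricks.connect import DatabricksSession

    # Local development: connect to the workspace defined by a named CLI profile.
    spark = DatabricksSession.builder.profile("my-databricks-profile").getOrCreate()
except ImportError:
    from pyspark.sql import SparkSession

    # On a Databricks cluster the active session is already available.
    spark = SparkSession.builder.getOrCreate()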
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (3)
src/mlops_with_databricks/data_preprocessing/preprocess.py (3)
7-8: Remove commented-out import.
The commented-out import of DatabricksSession should be removed if it's no longer needed. If this is a temporary change, please add a comment explaining why it's commented out.
52-52: Document or remove the commented-out OneHotEncoder.
If OneHotEncoder is intentionally disabled, please add a comment explaining why. If it's no longer needed, remove the commented line.
100-127: Add error handling and validation to the save_to_catalog method.
While the implementation is good, consider adding:
- Error handling for Spark operations
- Validation of input DataFrames
- Logging of successful/failed operations
Example implementation:
def save_to_catalog(self, train_set: pd.DataFrame, test_set: pd.DataFrame, spark: SparkSession):
    """Save the train and test sets into Databricks tables.

    Args:
        train_set: Training data
        test_set: Test data
        spark: Active SparkSession

    Raises:
        ValueError: If input DataFrames are empty
        SparkException: If table creation fails
    """
    if train_set.empty or test_set.empty:
        raise ValueError("Input DataFrames cannot be empty")

    try:
        train_set_with_timestamp = spark.createDataFrame(train_set).withColumn(
            "update_timestamp_utc", to_utc_timestamp(current_timestamp(), "UTC")
        )
        # ... rest of the implementation ...
        logging.info("Successfully saved datasets to catalog")
    except Exception as e:
        logging.error(f"Failed to save datasets: {str(e)}")
        raise
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- notebooks/mlflow_exp.py (1 hunks)
- src/mlops_with_databricks/data_preprocessing/preprocess.py (3 hunks)
- src/mlops_with_databricks/training/train.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- notebooks/mlflow_exp.py
- src/mlops_with_databricks/training/train.py
🔇 Additional comments (1)
src/mlops_with_databricks/data_preprocessing/preprocess.py (1)
21-36: LGTM! Good separation of concerns.
The refactoring improves the class design by:
- Separating data loading from initialization
- Adding flexibility with the from_pandas factory method
- Maintaining clear state management
if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()  # DatabricksSession.builder.profile("dbc-643c4c2b-d6c9").getOrCreate()

    df = spark.read.csv(
        "/Volumes/mlops_students/armak58/data/ad_click_dataset.csv", header=True, inferSchema=True
    ).toPandas()

    data_processor = DataProcessor.from_pandas(df)
    data_processor.preprocess_data()
    train_set, test_set = data_processor.split_data()
    data_processor.save_to_catalog(train_set=train_set, test_set=test_set, spark=spark)
🛠️ Refactor suggestion
Improve main execution block robustness.
The main execution block has several areas for improvement:
- Hardcoded paths should be moved to configuration
- Large DataFrame conversion to pandas could cause memory issues
- Missing error handling and logging
Consider this improved implementation:
import logging
import os

if __name__ == "__main__":
    try:
        # Get configuration from environment or config file
        data_path = os.getenv("DATA_PATH") or "/Volumes/mlops_students/armak58/data/ad_click_dataset.csv"
        spark = SparkSession.builder.getOrCreate()

        # Read with Spark; large datasets can be sampled before converting to pandas
        df = spark.read.csv(data_path, header=True, inferSchema=True)

        # Optional: Apply any necessary filtering/sampling before converting to pandas
        if df.count() > 1_000_000:  # Adjust threshold as needed
            df = df.sample(False, 0.5)

        pandas_df = df.toPandas()
        data_processor = DataProcessor.from_pandas(pandas_df)
        data_processor.preprocess_data()
        train_set, test_set = data_processor.split_data()
        data_processor.save_to_catalog(train_set=train_set, test_set=test_set, spark=spark)

        logging.info("Data processing completed successfully")
    except Exception as e:
        logging.error(f"Data processing failed: {str(e)}")
        raise
preprocessed_features = self.preprocessor.fit_transform(self.df)
preprocessed_features["click"] = self.df[AdClickDataColumns.click].astype("int64")
self.df = preprocessed_features
Potential data leakage in preprocessing pipeline.
The target variable 'click' is being added back to the preprocessed features. This could lead to data leakage if the same DataFrame is used for both feature extraction and target variable. Consider keeping features and target separate.
Apply this diff to separate features and target:
- preprocessed_features = self.preprocessor.fit_transform(self.df)
- preprocessed_features["click"] = self.df[AdClickDataColumns.click].astype("int64")
- self.df = preprocessed_features
+ self.X = self.preprocessor.fit_transform(self.df)
+ self.y = self.df[AdClickDataColumns.click].astype("int64")
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
self.X = self.preprocessor.fit_transform(self.df)
self.y = self.df[AdClickDataColumns.click].astype("int64")
approved! couple of comments, overall looks great!
)


def get_git_info():
nice! this is a great option for local development, might be overkill though as it will be part of cd pipeline
"update_timestamp_utc", to_utc_timestamp(current_timestamp(), "UTC") | ||
) | ||
|
||
train_set_with_timestamp.write.option("overwriteSchema", "true").mode("overwrite").saveAsTable( |
why overwrite? you will not be able to use it later for processing of new data.
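If the goal is to keep earlier rows available for processing new data later, an append-style write is one alternative. A sketch, where the fully qualified table name is a placeholder:

# Append keeps previously written rows; mergeSchema lets new columns be added if needed.
train_set_with_timestamp.write.mode("append").option("mergeSchema", "true").saveAsTable(
    "mlops_students.armak58.train_set"  # placeholder: catalog.schema.table
)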
Summary by CodeRabbit

New Features
Bug Fixes
- Updates to the DataProcessor class, including error handling and data processing improvements.
Documentation