Skip to content

Commit

Permalink
refactor: refactor code to use new Config class
Browse files Browse the repository at this point in the history
  • Loading branch information
Garett601 committed Oct 23, 2024
1 parent ed3e464 commit 6291fa8
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 99 deletions.
15 changes: 8 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

from loguru import logger

from power_consumption.config import Config
from power_consumption.model import ConsumptionModel
from power_consumption.preprocessing import DataProcessor
from power_consumption.utils import load_config, plot_actual_vs_predicted, plot_feature_importance, visualise_results
from power_consumption.utils import plot_actual_vs_predicted, plot_feature_importance, visualise_results

# Load configuration
config = load_config("./configs/project_configs.yml")
config = Config.from_yaml("./configs/project_configs.yml")

# Initialise data processor and preprocess data
processor = DataProcessor(dataset_id=849, config=config)
processor = DataProcessor(config=config)
processor.preprocess_data()

# Split data into train and test sets
Expand All @@ -31,17 +32,17 @@

# Log Mean Squared Error for each target
logger.info("Mean Squared Error for each target:")
for i, target in enumerate(config["target"]):
for i, target in enumerate(config.target.target):
logger.info(f"{target}: {mse[i]:.4f}")

# Log R-squared for each target
logger.info("\nR-squared for each target:")
for i, target in enumerate(config["target"]):
for i, target in enumerate(config.target.target):
logger.info(f"{target}: {r2[i]:.4f}")

# Visualise results
visualise_results(y_test, y_pred, config["target"])
plot_actual_vs_predicted(y_test.values, y_pred, config["target"])
visualise_results(y_test, y_pred, config.target.target)
plot_actual_vs_predicted(y_test.values, y_pred, config.target.target)

# Get and plot feature importance
feature_importance, feature_names = model.get_feature_importance()
Expand Down
18 changes: 9 additions & 9 deletions notebooks/week_1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
"source": [
"from power_consumption.preprocessing import DataProcessor\n",
"from power_consumption.model import ConsumptionModel\n",
"\n",
"from power_consumption.utils import load_config"
"from power_consumption.config import Config"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"config = load_config(\"../configs/project_configs.yml\")"
"config = Config.from_yaml(\"../configs/project_configs.yml\")"
]
},
{
Expand All @@ -27,7 +26,8 @@
"metadata": {},
"outputs": [],
"source": [
"processor = DataProcessor(dataset_id=849, config=config)\n",
"\n",
"processor = DataProcessor(config=config)\n",
"processor.preprocess_data()\n",
"X_train, X_test, y_train, y_test = processor.split_data()\n",
"X_train_transformed, X_test_transformed = processor.fit_transform_features(X_train, X_test)"
Expand Down Expand Up @@ -76,11 +76,11 @@
"outputs": [],
"source": [
"print(\"Mean Squared Error for each target:\")\n",
"for i, target in enumerate(config['target']):\n",
"for i, target in enumerate(config.target.target):\n",
" print(f\"{target}: {mse[i]:.4f}\")\n",
"\n",
"print(\"\\nR-squared for each target:\")\n",
"for i, target in enumerate(config['target']):\n",
"for i, target in enumerate(config.target.target):\n",
" print(f\"{target}: {r2[i]:.4f}\")"
]
},
Expand All @@ -99,8 +99,8 @@
"metadata": {},
"outputs": [],
"source": [
"visualise_results(y_test, y_pred, config['target'])\n",
"plot_actual_vs_predicted(y_test.values, y_pred, config['target'])\n",
"visualise_results(y_test, y_pred, config.target.target)\n",
"plot_actual_vs_predicted(y_test.values, y_pred, config.target.target)\n",
"\n",
"feature_importance, feature_names = model.get_feature_importance()\n",
"plot_feature_importance(feature_importance, feature_names)"
Expand Down
48 changes: 0 additions & 48 deletions power_consumption/main.py

This file was deleted.

10 changes: 6 additions & 4 deletions power_consumption/model/rf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Main module for the power consumption model.
"""

from typing import Dict, Self, Tuple
from typing import Self, Tuple

import numpy as np
from sklearn.compose import ColumnTransformer
Expand All @@ -11,6 +11,8 @@
from sklearn.multioutput import MultiOutputRegressor
from sklearn.pipeline import Pipeline

from power_consumption.config import Config


class ConsumptionModel:
"""
Expand All @@ -31,7 +33,7 @@ class ConsumptionModel:
The complete model pipeline including preprocessing and regression steps.
"""

def __init__(self, preprocessor: ColumnTransformer, config: Dict) -> None:
def __init__(self, preprocessor: ColumnTransformer, config: Config) -> None:
"""
Initialise the ConsumptionModel.
"""
Expand All @@ -43,8 +45,8 @@ def __init__(self, preprocessor: ColumnTransformer, config: Dict) -> None:
"regressor",
MultiOutputRegressor(
RandomForestRegressor(
n_estimators=config["parameters"]["n_estimators"],
max_depth=config["parameters"]["max_depth"],
n_estimators=config.hyperparameters.n_estimators,
max_depth=config.hyperparameters.max_depth,
random_state=42,
)
),
Expand Down
19 changes: 11 additions & 8 deletions power_consumption/preprocessing/data_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Data preprocessing module for the power consumption dataset."""

from __future__ import annotations

import os
from typing import Any, Dict, Optional, Tuple
from typing import Optional, Tuple

import numpy as np
import pandas as pd
Expand All @@ -12,18 +14,19 @@
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from ucimlrepo import fetch_ucirepo

from power_consumption.config import Config
from power_consumption.schemas.processed_data import PowerConsumptionSchema


class DataProcessor:
def __init__(self, config: Dict[str, Any]):
def __init__(self, config: Config):
    """
    Initialise the DataProcessor.

    Parameters
    ----------
    config : Config
        Project configuration object; `load_data` reads `config.dataset.id`
        from it, and later steps read the feature/target sections.
    """
    # Populated by load_data() and the downstream preprocessing steps.
    self.data: Optional[pd.DataFrame] = None
    self.X: Optional[pd.DataFrame] = None
    self.y: Optional[pd.DataFrame] = None
    # NOTE(review): diff residue — the next line is the pre-refactor
    # assignment, superseded by the Config-typed assignment that follows.
    self.config: Dict[str, Any] = config
    self.config: Config = config
    # Built later by create_preprocessor(); None until then.
    self.preprocessor: Optional[ColumnTransformer] = None
    # Fetch the dataset immediately so the instance is ready to use.
    self.load_data()

Expand All @@ -41,7 +44,7 @@ def load_data(self) -> None:
If loading from UCI ML Repository fails, the method will attempt to load
the data from '../data/Tetuan City power consumption.csv'.
"""
dataset_id = self.config["dataset"]["id"]
dataset_id = self.config.dataset.id
try:
dataset = fetch_ucirepo(id=dataset_id)
logger.info(
Expand Down Expand Up @@ -76,8 +79,8 @@ def create_preprocessor(self) -> ColumnTransformer:
ColumnTransformer
The preprocessing pipeline.
"""
numeric_features = self.config["num_features"]
categorical_features = self.config["cat_features"]
numeric_features = self.config.features.num_features
categorical_features = self.config.features.cat_features

logger.info(f"Numeric features: {numeric_features}")
logger.info(f"Categorical features: {categorical_features}")
Expand Down Expand Up @@ -141,8 +144,8 @@ def split_data(self, test_size: float = 0.2) -> Tuple[pd.DataFrame, pd.DataFrame
tuple of pd.DataFrame
X_train, X_test, y_train, y_test
"""
target_columns = self.config["target"]
feature_columns = self.config["num_features"] + self.config["cat_features"]
target_columns = self.config.target.target
feature_columns = self.config.features.num_features + self.config.features.cat_features

X = self.data[feature_columns]
y = self.data[target_columns]
Expand Down
23 changes: 0 additions & 23 deletions power_consumption/utils.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,7 @@
"""Utility functions for the power consumption project."""

from typing import Any, Dict

import matplotlib.pyplot as plt
import numpy as np
import yaml
from loguru import logger


def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load configuration from a YAML file.

    Parameters
    ----------
    config_path : str
        Path to the configuration YAML file.

    Returns
    -------
    Dict[str, Any]
        Configuration dictionary parsed by ``yaml.safe_load``.
    """
    logger.info(f"Loading configuration from {config_path}")
    # Pin the encoding: without it, open() uses the platform default,
    # which can mis-decode non-ASCII config values on some systems.
    with open(config_path, "r", encoding="utf-8") as config_file:
        return yaml.safe_load(config_file)


def visualise_results(y_test, y_pred, target_names):
Expand Down

0 comments on commit 6291fa8

Please sign in to comment.