From b599df7e3af5b220b40a8d67c97151d70712f5cb Mon Sep 17 00:00:00 2001 From: Garett Sidwell <63107655+Garett601@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:48:56 +0200 Subject: [PATCH] test(tests/): refactor tests to use new Config class --- tests/conftest.py | 13 +++++-- tests/conftest.yml | 34 ++++++++++--------- tests/model/test_rf_model.py | 32 +++++++++++------ tests/preprocessing/test_data_preprocessor.py | 7 ++-- tests/test_utils.py | 15 -------- 5 files changed, 54 insertions(+), 47 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 489ca94..63b5b7d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,16 @@ import pytest -import yaml from pathlib import Path +from power_consumption.config import Config @pytest.fixture def project_config(): + """ + Fixture to provide the project configuration. + + Returns + ------- + Config + An instance of the Config class populated with data from the YAML file. + """ config_path = Path(__file__).parent / "conftest.yml" - with open(config_path, "r") as file: - return yaml.safe_load(file) + return Config.from_yaml(str(config_path)) diff --git a/tests/conftest.yml b/tests/conftest.yml index a708976..af20127 100644 --- a/tests/conftest.yml +++ b/tests/conftest.yml @@ -1,29 +1,31 @@ catalog_name: heiaepgah71pwedmld01001 schema_name: power_consumption -parameters: +hyperparameters: learning_rate: 0.01 n_estimators: 1000 max_depth: 6 -num_features: - - Temperature - - Humidity - - Wind_Speed - - Hour - - Day - - Month - - general_diffuse_flows - - diffuse_flows +features: + num_features: + - Temperature + - Humidity + - Wind_Speed + - Hour + - Day + - Month + - general_diffuse_flows + - diffuse_flows -cat_features: - - DayOfWeek - - IsWeekend + cat_features: + - DayOfWeek + - IsWeekend target: - - Zone_1_Power_Consumption - - Zone_2_Power_Consumption - - Zone_3_Power_Consumption + target: + - Zone_1_Power_Consumption + - Zone_2_Power_Consumption + - Zone_3_Power_Consumption dataset: id: 849 diff --git a/tests/model/test_rf_model.py b/tests/model/test_rf_model.py index 2564bd0..31b9c41 100644 --- a/tests/model/test_rf_model.py +++ b/tests/model/test_rf_model.py @@ -5,6 +5,7 @@ from sklearn.pipeline import Pipeline from power_consumption.model import ConsumptionModel +from power_consumption.config import Config, Hyperparameters, Features, Target, Dataset @pytest.fixture def sample_data(): @@ -15,12 +16,21 @@ def sample_data(): @pytest.fixture def model_config(): - return { - "parameters": { - "n_estimators": 100, - "max_depth": 5 - } - } + return Config( + catalog_name="test_catalog", + schema_name="test_schema", + hyperparameters=Hyperparameters( + learning_rate=0.01, + n_estimators=100, + max_depth=5 + ), + features=Features( + num_features=["feature1", "feature2", "feature3", "feature4", "feature5"], + cat_features=[] + ), + target=Target(target=["target1", "target2"]), + dataset=Dataset(id=1) + ) @pytest.fixture def preprocessor(): @@ -32,8 +42,10 @@ def preprocessor(): def test_model_initialisation(preprocessor, model_config): model = ConsumptionModel(preprocessor, model_config) - assert model.config == model_config + assert isinstance(model.config, Config) assert isinstance(model.model, Pipeline) + assert model.model.named_steps["regressor"].estimator.n_estimators == model_config.hyperparameters.n_estimators + assert model.model.named_steps["regressor"].estimator.max_depth == model_config.hyperparameters.max_depth def test_model_train_and_predict(sample_data, preprocessor, model_config): X_train, y_train, X_test, _ = sample_data @@ -43,7 +55,7 @@ def test_model_train_and_predict(sample_data, preprocessor, model_config): assert trained_model is model y_pred = model.predict(X_test) - assert y_pred.shape == (20, 2) + assert y_pred.shape == (20, len(model_config.target.target)) def test_model_evaluate(sample_data, preprocessor, model_config): X_train, y_train, X_test, y_test = sample_data @@ -51,8 +63,8 @@ def test_model_evaluate(sample_data, preprocessor, model_config): model.train(X_train, y_train) mse, r2 = model.evaluate(X_test, y_test) - assert mse.shape == (2,) - assert r2.shape == (2,) + assert mse.shape == (len(model_config.target.target),) + assert r2.shape == (len(model_config.target.target),) assert np.all(mse >= 0) assert np.all(r2 <= 1) and np.all(r2 >= -1) diff --git a/tests/preprocessing/test_data_preprocessor.py b/tests/preprocessing/test_data_preprocessor.py index 0494269..f5764e0 100644 --- a/tests/preprocessing/test_data_preprocessor.py +++ b/tests/preprocessing/test_data_preprocessor.py @@ -2,6 +2,7 @@ import numpy as np import pytest from power_consumption.preprocessing.data_preprocessor import DataProcessor +from power_consumption.config import Config @pytest.fixture @@ -97,9 +98,9 @@ def test_split_data(data_processor_with_data): assert len(X_train) > len(X_test) assert len(y_train) > len(y_test) assert X_train.shape[1] == len( - data_processor_with_data.config["num_features"] - ) + len(data_processor_with_data.config["cat_features"]) - assert y_train.shape[1] == len(data_processor_with_data.config["target"]) + data_processor_with_data.config.features.num_features + ) + len(data_processor_with_data.config.features.cat_features) + assert y_train.shape[1] == len(data_processor_with_data.config.target.target) def test_fit_transform_features(data_processor_with_data): diff --git a/tests/test_utils.py b/tests/test_utils.py index 4cbefed..74a3f53 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,27 +8,12 @@ import pytest from power_consumption.utils import ( - load_config, visualise_results, plot_actual_vs_predicted, plot_feature_importance, ) -def test_load_config(): - conftest_path = Path(__file__).parent / "conftest.yml" - - loaded_config = load_config(conftest_path) - - assert loaded_config["catalog_name"] == "heiaepgah71pwedmld01001" - assert loaded_config["schema_name"] == "power_consumption" - assert loaded_config["parameters"]["learning_rate"] == 0.01 - assert loaded_config["parameters"]["n_estimators"] == 1000 - assert loaded_config["parameters"]["max_depth"] == 6 - assert "Temperature" in loaded_config["num_features"] - assert "DayOfWeek" in loaded_config["cat_features"] - assert "Zone_1_Power_Consumption" in loaded_config["target"] - @pytest.mark.parametrize("n_targets", [1, 3]) def test_visualise_results(n_targets, mocker):