From 7277287c2891694e4ee70c0b863cf20e76ba5530 Mon Sep 17 00:00:00 2001 From: Garett Sidwell <63107655+Garett601@users.noreply.github.com> Date: Wed, 23 Oct 2024 13:16:22 +0200 Subject: [PATCH] refactor(data_preprocessor.py): improve data path resolution remove hardcoded paths and use Pathlib to resolve the data/ path for local read --- .../preprocessing/data_preprocessor.py | 31 +++-- tests/preprocessing/test_data_preprocessor.py | 13 +- uv.lock | 118 +++++++++--------- 3 files changed, 91 insertions(+), 71 deletions(-) diff --git a/power_consumption/preprocessing/data_preprocessor.py b/power_consumption/preprocessing/data_preprocessor.py index 5707c35..0ada763 100644 --- a/power_consumption/preprocessing/data_preprocessor.py +++ b/power_consumption/preprocessing/data_preprocessor.py @@ -2,7 +2,7 @@ from __future__ import annotations -import os +from pathlib import Path from typing import Optional, Tuple import numpy as np @@ -34,15 +34,21 @@ def load_data(self) -> None: """ Load the dataset from UCI ML Repository or local CSV file. - Parameters - ---------- - dataset_id : int - The ID of the dataset to fetch from UCI ML Repository. + This method attempts to load the dataset from the UCI ML Repository using the dataset ID + specified in the configuration. If that fails, it falls back to loading from a local CSV file. + + The dataset ID is obtained from the configuration (self.config.dataset.id). Notes ----- If loading from UCI ML Repository fails, the method will attempt to load - the data from '../data/Tetuan City power consumption.csv'. + the data from '../data/Tetuan City power consumption.csv' or + './data/Tetuan City power consumption.csv'. + + Raises + ------ + Exception + If both UCI ML Repository fetch and local CSV file loading fail. """ dataset_id = self.config.dataset.id try: @@ -54,12 +60,19 @@ def load_data(self) -> None: except Exception as e: logger.warning(f"Failed to load data from UCI ML Repository: {e}") logger.info("Attempting to load data from local CSV file") - csv_path = "../data/Tetuan City power consumption.csv" - if not os.path.exists(csv_path): - csv_path = "./data/Tetuan City power consumption.csv" + data_dir = Path(__file__).resolve().parents[2] / "data" + csv_filename = "Tetuan City power consumption.csv" + csv_path = data_dir / csv_filename + + if not csv_path.exists(): + csv_path = Path.cwd() / "data" / csv_filename + try: self.data = pd.read_csv(csv_path) logger.info(f"Successfully loaded data from {csv_path}") + except FileNotFoundError: + logger.error(f"CSV file not found at {csv_path}") + raise except Exception as e: logger.error(f"Failed to load data from {csv_path}: {e}") raise diff --git a/tests/preprocessing/test_data_preprocessor.py b/tests/preprocessing/test_data_preprocessor.py index f5764e0..2f438b0 100644 --- a/tests/preprocessing/test_data_preprocessor.py +++ b/tests/preprocessing/test_data_preprocessor.py @@ -3,6 +3,7 @@ import pytest from power_consumption.preprocessing.data_preprocessor import DataProcessor from power_consumption.config import Config +from pathlib import Path @pytest.fixture @@ -54,17 +55,23 @@ def test_load_data(project_config, mocker): def test_load_data_fallback(project_config, mocker): - mock_fetch = mocker.patch( + mocker.patch( "power_consumption.preprocessing.data_preprocessor.fetch_ucirepo", side_effect=Exception("UCI fetch failed"), ) + mock_read_csv = mocker.patch("pandas.read_csv") mock_read_csv.return_value = pd.DataFrame({"B": [4, 5, 6]}) + mock_path = mocker.patch("power_consumption.preprocessing.data_preprocessor.Path") + mock_path.return_value.exists.return_value = True + mock_path.return_value.resolve.return_value.parents.__getitem__.return_value = Path("/mocked/project/root") + processor = DataProcessor(config=project_config) - mock_fetch.assert_called_once_with(id=849) - mock_read_csv.assert_called_once_with("./data/Tetuan City power consumption.csv") + mock_read_csv.assert_called_once() + + assert processor.data is not None assert processor.data.equals(pd.DataFrame({"B": [4, 5, 6]})) diff --git a/uv.lock b/uv.lock index beb9c37..ece0c82 100644 --- a/uv.lock +++ b/uv.lock @@ -1066,65 +1066,6 @@ databricks = [ { name = "google-cloud-storage" }, ] -[[package]] -name = "mlops-with-databricks" -version = "0.0.1" -source = { virtual = "." } -dependencies = [ - { name = "cffi" }, - { name = "cloudpickle" }, - { name = "databricks-feature-engineering" }, - { name = "lightgbm" }, - { name = "loguru" }, - { name = "matplotlib" }, - { name = "mlflow" }, - { name = "numpy" }, - { name = "pandas" }, - { name = "pandera" }, - { name = "pyarrow" }, - { name = "scikit-learn" }, - { name = "scipy" }, - { name = "ucimlrepo" }, -] - -[package.optional-dependencies] -dev = [ - { name = "databricks-connect" }, - { name = "databricks-sdk" }, - { name = "ipykernel" }, - { name = "pip" }, - { name = "pytest" }, - { name = "pytest-cov" }, - { name = "pytest-mock" }, - { name = "pytest-sugar" }, -] - -[package.metadata] -requires-dist = [ - { name = "cffi", specifier = ">=1.17.1,<2" }, - { name = "cloudpickle", specifier = ">=3.0.0,<4" }, - { name = "databricks-connect", marker = "extra == 'dev'", specifier = ">=15.4.1,<16" }, - { name = "databricks-feature-engineering", specifier = ">=0.6,<1" }, - { name = "databricks-sdk", marker = "extra == 'dev'", specifier = ">=0.32.0,<0.33" }, - { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5,<7" }, - { name = "lightgbm", specifier = ">=4.5.0,<5" }, - { name = "loguru", specifier = ">=0.7.2" }, - { name = "matplotlib", specifier = ">=3.9.2,<4" }, - { name = "mlflow", specifier = ">=2.16.0,<3" }, - { name = "numpy", specifier = ">=1.26.4,<2" }, - { name = "pandas", specifier = ">=2.2.2,<3" }, - { name = "pandera", specifier = ">=0.20.4" }, - { name = "pip", marker = "extra == 'dev'", specifier = ">=24.2" }, - { name = "pyarrow", specifier = ">=15.0.2,<16" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" }, - { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=5.0.0" }, - { name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.14.0" }, - { name = "pytest-sugar", marker = "extra == 'dev'", specifier = ">=1.0.0" }, - { name = "scikit-learn", specifier = ">=1.5.1,<2" }, - { name = "scipy", specifier = ">=1.14.1,<2" }, - { name = "ucimlrepo", specifier = ">=0.0.7" }, -] - [[package]] name = "multimethod" version = "1.10" @@ -1321,6 +1262,65 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "power-consumption" +version = "0.0.1" +source = { virtual = "." } +dependencies = [ + { name = "cffi" }, + { name = "cloudpickle" }, + { name = "databricks-feature-engineering" }, + { name = "lightgbm" }, + { name = "loguru" }, + { name = "matplotlib" }, + { name = "mlflow" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "pandera" }, + { name = "pyarrow" }, + { name = "scikit-learn" }, + { name = "scipy" }, + { name = "ucimlrepo" }, +] + +[package.optional-dependencies] +dev = [ + { name = "databricks-connect" }, + { name = "databricks-sdk" }, + { name = "ipykernel" }, + { name = "pip" }, + { name = "pytest" }, + { name = "pytest-cov" }, + { name = "pytest-mock" }, + { name = "pytest-sugar" }, +] + +[package.metadata] +requires-dist = [ + { name = "cffi", specifier = ">=1.17.1,<2" }, + { name = "cloudpickle", specifier = ">=3.0.0,<4" }, + { name = "databricks-connect", marker = "extra == 'dev'", specifier = ">=15.4.1,<16" }, + { name = "databricks-feature-engineering", specifier = ">=0.6,<1" }, + { name = "databricks-sdk", marker = "extra == 'dev'", specifier = ">=0.32.0,<0.33" }, + { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5,<7" }, + { name = "lightgbm", specifier = ">=4.5.0,<5" }, + { name = "loguru", specifier = ">=0.7.2" }, + { name = "matplotlib", specifier = ">=3.9.2,<4" }, + { name = "mlflow", specifier = ">=2.16.0,<3" }, + { name = "numpy", specifier = ">=1.26.4,<2" }, + { name = "pandas", specifier = ">=2.2.2,<3" }, + { name = "pandera", specifier = ">=0.20.4" }, + { name = "pip", marker = "extra == 'dev'", specifier = ">=24.2" }, + { name = "pyarrow", specifier = ">=15.0.2,<16" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" }, + { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=5.0.0" }, + { name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.14.0" }, + { name = "pytest-sugar", marker = "extra == 'dev'", specifier = ">=1.0.0" }, + { name = "scikit-learn", specifier = ">=1.5.1,<2" }, + { name = "scipy", specifier = ">=1.14.1,<2" }, + { name = "ucimlrepo", specifier = ">=0.0.7" }, +] + [[package]] name = "prompt-toolkit" version = "3.0.48"