From 8511cb729d0f20a365513279fff982842095f0e7 Mon Sep 17 00:00:00 2001 From: Diego Marvid Date: Fri, 1 Nov 2024 13:37:43 -0300 Subject: [PATCH] fix tests --- ml_garden/core/pipeline.py | 2 +- tests/core/steps/test_tabular_split.py | 7 +++++++ tests/core/test_model_registry.py | 7 +++---- tests/core/test_step_registry.py | 16 ++++++++-------- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/ml_garden/core/pipeline.py b/ml_garden/core/pipeline.py index d75266e..6fa3a7e 100644 --- a/ml_garden/core/pipeline.py +++ b/ml_garden/core/pipeline.py @@ -339,7 +339,7 @@ def plot_feature_importance(df: pd.DataFrame) -> None: mode_name = "train" if data.is_train else "predict" run = ( f"{self.__class__.__name__}_{mode_name}_" - f"{datetime.now(datetime.UTC).strftime('%Y%m%d_%H%M%S')}" + f"{datetime.now().strftime('%Y%m%d_%H%M%S')}" # noqa: DTZ005 ) self.logger.info("Run name not provided. Using default run name: %s", run) diff --git a/tests/core/steps/test_tabular_split.py b/tests/core/steps/test_tabular_split.py index bec21ca..7e65876 100644 --- a/tests/core/steps/test_tabular_split.py +++ b/tests/core/steps/test_tabular_split.py @@ -3,11 +3,18 @@ import pytest from ml_garden.core import DataContainer +from ml_garden.core.random_state_generator import RandomStateManager from ml_garden.core.steps import TabularSplitStep # ruff: noqa: ERA001 +@pytest.fixture(autouse=True) +def _setup_random_state() -> None: + """Initialize random state before each test.""" + RandomStateManager.initialize(seed=42) + + @pytest.fixture() def input_data() -> pd.DataFrame: # Data as a dictionary diff --git a/tests/core/test_model_registry.py b/tests/core/test_model_registry.py index ad1a1a1..ac84932 100644 --- a/tests/core/test_model_registry.py +++ b/tests/core/test_model_registry.py @@ -22,7 +22,7 @@ def test_register_model() -> None: def test_register_non_model_class() -> None: registry = ModelRegistry() - with pytest.raises(ValueError, match=" must be a subclass of Model"): + with pytest.raises(TypeError, match=" must be a subclass of Model"): registry.register_model(int) @@ -90,6 +90,5 @@ def test_auto_register_models_import_error(mock_import_module: MagicMock) -> Non mock_import_module.side_effect = ImportError registry = ModelRegistry() - registry.auto_register_models_from_package("invalid_package") - - assert len(registry.get_all_model_classes()) == 0 + with pytest.raises(ImportError): + registry.auto_register_models_from_package("invalid_package") diff --git a/tests/core/test_step_registry.py b/tests/core/test_step_registry.py index fc9ca16..90c4279 100644 --- a/tests/core/test_step_registry.py +++ b/tests/core/test_step_registry.py @@ -22,7 +22,7 @@ def test_register_step() -> None: def test_register_non_step_class() -> None: registry = StepRegistry() - with pytest.raises(ValueError, match="must be a subclass of PipelineStep"): + with pytest.raises(TypeError, match="must be a subclass of PipelineStep"): registry.register_step(int) @@ -87,12 +87,11 @@ def test_auto_register_steps_from_package( @patch("ml_garden.core.step_registry.importlib.import_module") def test_auto_register_steps_import_error(mock_import_module: MagicMock) -> None: - mock_import_module.side_effect = ImportError + mock_import_module.side_effect = ImportError("Package not found") registry = StepRegistry() - registry.auto_register_steps_from_package("invalid_package") - - assert len(registry.get_all_step_classes()) == 0 + with pytest.raises(ImportError, match="Failed to import package: invalid_package"): + registry.auto_register_steps_from_package("invalid_package") @patch("ml_garden.core.step_registry.os.listdir") @@ -139,6 +138,7 @@ def test_load_and_register_custom_steps_exception( mock_spec.loader.exec_module.side_effect = Exception("Test Exception") registry = StepRegistry() - registry.load_and_register_custom_steps("custom_steps_path") - - assert len(registry.get_all_step_classes()) == 0 + with pytest.raises( + ImportError, match="Failed to load module: custom_step. Error: Test Exception" + ): + registry.load_and_register_custom_steps("custom_steps_path")