Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Nov 1, 2024
1 parent 1eb2eb2 commit 8511cb7
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ml_garden/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def plot_feature_importance(df: pd.DataFrame) -> None:
mode_name = "train" if data.is_train else "predict"
run = (
f"{self.__class__.__name__}_{mode_name}_"
f"{datetime.now(datetime.UTC).strftime('%Y%m%d_%H%M%S')}"
f"{datetime.now().strftime('%Y%m%d_%H%M%S')}" # noqa: DTZ005
)
self.logger.info("Run name not provided. Using default run name: %s", run)

Expand Down
7 changes: 7 additions & 0 deletions tests/core/steps/test_tabular_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@
import pytest

from ml_garden.core import DataContainer
from ml_garden.core.random_state_generator import RandomStateManager
from ml_garden.core.steps import TabularSplitStep

# ruff: noqa: ERA001


@pytest.fixture(autouse=True)
def _setup_random_state() -> None:
"""Initialize random state before each test."""
RandomStateManager.initialize(seed=42)


@pytest.fixture()
def input_data() -> pd.DataFrame:
# Data as a dictionary
Expand Down
7 changes: 3 additions & 4 deletions tests/core/test_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_register_model() -> None:

def test_register_non_model_class() -> None:
registry = ModelRegistry()
with pytest.raises(ValueError, match="<class 'int'> must be a subclass of Model"):
with pytest.raises(TypeError, match="<class 'int'> must be a subclass of Model"):
registry.register_model(int)


Expand Down Expand Up @@ -90,6 +90,5 @@ def test_auto_register_models_import_error(mock_import_module: MagicMock) -> Non
mock_import_module.side_effect = ImportError

registry = ModelRegistry()
registry.auto_register_models_from_package("invalid_package")

assert len(registry.get_all_model_classes()) == 0
with pytest.raises(ImportError):
registry.auto_register_models_from_package("invalid_package")
16 changes: 8 additions & 8 deletions tests/core/test_step_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_register_step() -> None:

def test_register_non_step_class() -> None:
registry = StepRegistry()
with pytest.raises(ValueError, match="must be a subclass of PipelineStep"):
with pytest.raises(TypeError, match="must be a subclass of PipelineStep"):
registry.register_step(int)


Expand Down Expand Up @@ -87,12 +87,11 @@ def test_auto_register_steps_from_package(

@patch("ml_garden.core.step_registry.importlib.import_module")
def test_auto_register_steps_import_error(mock_import_module: MagicMock) -> None:
mock_import_module.side_effect = ImportError
mock_import_module.side_effect = ImportError("Package not found")

registry = StepRegistry()
registry.auto_register_steps_from_package("invalid_package")

assert len(registry.get_all_step_classes()) == 0
with pytest.raises(ImportError, match="Failed to import package: invalid_package"):
registry.auto_register_steps_from_package("invalid_package")


@patch("ml_garden.core.step_registry.os.listdir")
Expand Down Expand Up @@ -139,6 +138,7 @@ def test_load_and_register_custom_steps_exception(
mock_spec.loader.exec_module.side_effect = Exception("Test Exception")

registry = StepRegistry()
registry.load_and_register_custom_steps("custom_steps_path")

assert len(registry.get_all_step_classes()) == 0
with pytest.raises(
ImportError, match="Failed to load module: custom_step. Error: Test Exception"
):
registry.load_and_register_custom_steps("custom_steps_path")

0 comments on commit 8511cb7

Please sign in to comment.