Skip to content

Commit

Permalink
test folders
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Feb 23, 2025
1 parent 7de5285 commit 57dfe3a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
4 changes: 3 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ addopts =
--no-cov-on-fail

markers =
testpaths = tests/
testpaths =
tests/
pytorch_forecasting/tests/
log_cli_level = ERROR
log_format = %(asctime)s %(levelname)s %(message)s
log_date_format = %Y-%m-%d %H:%M:%S
Expand Down
43 changes: 37 additions & 6 deletions pytorch_forecasting/tests/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from skbase.testing import (
BaseFixtureGenerator as _BaseFixtureGenerator,
TestAllObjects as _TestAllObjects,
)
import pytest
from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator

from pytorch_forecasting._registry import all_objects
from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
Expand Down Expand Up @@ -110,10 +108,43 @@ def _all_objects(self):

# which sequence the conditional fixtures are generated in
fixture_sequence = [
"object_metadata",
"object_class",
"object_instance",
]

def _generate_object_metadata(self, test_name, **kwargs):
"""Return object class fixtures.
Fixtures parametrized
---------------------
object_class: object inheriting from BaseObject
ranges over all object classes not excluded by self.excluded_tests
"""
object_classes_to_test = [
est for est in self._all_objects() if not self.is_excluded(test_name, est)
]
object_names = [est.__name__ for est in object_classes_to_test]

return object_classes_to_test, object_names

def _generate_object_class(self, test_name, **kwargs):
"""Return object class fixtures.
Fixtures parametrized
---------------------
object_class: object inheriting from BaseObject
ranges over all object classes not excluded by self.excluded_tests
"""
all_metadata = self._all_objects()
all_cls = [est.get_model_cls() for est in all_metadata]
object_classes_to_test = [
est for est in all_cls if not self.is_excluded(test_name, est)
]
object_names = [est.__name__ for est in object_classes_to_test]

return object_classes_to_test, object_names


def _integration(
estimator_cls,
Expand Down Expand Up @@ -210,7 +241,7 @@ def _integration(
)


class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects):
class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator):
"""Generic tests for all objects in the mini package."""

def test_doctest_examples(self, object_class):
Expand All @@ -219,6 +250,6 @@ def test_doctest_examples(self, object_class):

doctest.run_docstring_examples(object_class, globals())

def certain_failure(self, object_class):
def test_certain_failure(self, object_class):
"""Fails for certain, for testing."""
assert False

0 comments on commit 57dfe3a

Please sign in to comment.