Skip to content

Commit 57dfe3a

Browse files
committed
test folders
1 parent 7de5285 commit 57dfe3a

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

pytest.ini

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ addopts =
1010
--no-cov-on-fail
1111

1212
markers =
13-
testpaths = tests/
13+
testpaths =
14+
tests/
15+
pytorch_forecasting/tests/
1416
log_cli_level = ERROR
1517
log_format = %(asctime)s %(levelname)s %(message)s
1618
log_date_format = %Y-%m-%d %H:%M:%S

pytorch_forecasting/tests/test_all_estimators.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
import lightning.pytorch as pl
77
from lightning.pytorch.callbacks import EarlyStopping
88
from lightning.pytorch.loggers import TensorBoardLogger
9-
from skbase.testing import (
10-
BaseFixtureGenerator as _BaseFixtureGenerator,
11-
TestAllObjects as _TestAllObjects,
12-
)
9+
import pytest
10+
from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator
1311

1412
from pytorch_forecasting._registry import all_objects
1513
from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
@@ -110,10 +108,43 @@ def _all_objects(self):
110108

111109
# which sequence the conditional fixtures are generated in
112110
fixture_sequence = [
111+
"object_metadata",
113112
"object_class",
114113
"object_instance",
115114
]
116115

116+
def _generate_object_metadata(self, test_name, **kwargs):
117+
"""Return object class fixtures.
118+
119+
Fixtures parametrized
120+
---------------------
121+
object_class: object inheriting from BaseObject
122+
ranges over all object classes not excluded by self.excluded_tests
123+
"""
124+
object_classes_to_test = [
125+
est for est in self._all_objects() if not self.is_excluded(test_name, est)
126+
]
127+
object_names = [est.__name__ for est in object_classes_to_test]
128+
129+
return object_classes_to_test, object_names
130+
131+
def _generate_object_class(self, test_name, **kwargs):
132+
"""Return object class fixtures.
133+
134+
Fixtures parametrized
135+
---------------------
136+
object_class: object inheriting from BaseObject
137+
ranges over all object classes not excluded by self.excluded_tests
138+
"""
139+
all_metadata = self._all_objects()
140+
all_cls = [est.get_model_cls() for est in all_metadata]
141+
object_classes_to_test = [
142+
est for est in all_cls if not self.is_excluded(test_name, est)
143+
]
144+
object_names = [est.__name__ for est in object_classes_to_test]
145+
146+
return object_classes_to_test, object_names
147+
117148

118149
def _integration(
119150
estimator_cls,
@@ -210,7 +241,7 @@ def _integration(
210241
)
211242

212243

213-
class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects):
244+
class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator):
214245
"""Generic tests for all objects in the mini package."""
215246

216247
def test_doctest_examples(self, object_class):
@@ -219,6 +250,6 @@ def test_doctest_examples(self, object_class):
219250

220251
doctest.run_docstring_examples(object_class, globals())
221252

222-
def certain_failure(self, object_class):
253+
def test_certain_failure(self, object_class):
223254
"""Fails for certain, for testing."""
224255
assert False

0 commit comments

Comments
 (0)