6
6
import lightning .pytorch as pl
7
7
from lightning .pytorch .callbacks import EarlyStopping
8
8
from lightning .pytorch .loggers import TensorBoardLogger
9
- from skbase .testing import (
10
- BaseFixtureGenerator as _BaseFixtureGenerator ,
11
- TestAllObjects as _TestAllObjects ,
12
- )
9
+ import pytest
10
+ from skbase .testing import BaseFixtureGenerator as _BaseFixtureGenerator
13
11
14
12
from pytorch_forecasting ._registry import all_objects
15
13
from pytorch_forecasting .tests ._config import EXCLUDE_ESTIMATORS , EXCLUDED_TESTS
@@ -110,10 +108,43 @@ def _all_objects(self):
110
108
111
109
# which sequence the conditional fixtures are generated in
112
110
fixture_sequence = [
111
+ "object_metadata" ,
113
112
"object_class" ,
114
113
"object_instance" ,
115
114
]
116
115
116
+ def _generate_object_metadata (self , test_name , ** kwargs ):
117
+ """Return object class fixtures.
118
+
119
+ Fixtures parametrized
120
+ ---------------------
121
+ object_class: object inheriting from BaseObject
122
+ ranges over all object classes not excluded by self.excluded_tests
123
+ """
124
+ object_classes_to_test = [
125
+ est for est in self ._all_objects () if not self .is_excluded (test_name , est )
126
+ ]
127
+ object_names = [est .__name__ for est in object_classes_to_test ]
128
+
129
+ return object_classes_to_test , object_names
130
+
131
+ def _generate_object_class (self , test_name , ** kwargs ):
132
+ """Return object class fixtures.
133
+
134
+ Fixtures parametrized
135
+ ---------------------
136
+ object_class: object inheriting from BaseObject
137
+ ranges over all object classes not excluded by self.excluded_tests
138
+ """
139
+ all_metadata = self ._all_objects ()
140
+ all_cls = [est .get_model_cls () for est in all_metadata ]
141
+ object_classes_to_test = [
142
+ est for est in all_cls if not self .is_excluded (test_name , est )
143
+ ]
144
+ object_names = [est .__name__ for est in object_classes_to_test ]
145
+
146
+ return object_classes_to_test , object_names
147
+
117
148
118
149
def _integration (
119
150
estimator_cls ,
@@ -210,7 +241,7 @@ def _integration(
210
241
)
211
242
212
243
213
- class TestAllPtForecasters (PackageConfig , BaseFixtureGenerator , _TestAllObjects ):
244
+ class TestAllPtForecasters (PackageConfig , BaseFixtureGenerator ):
214
245
"""Generic tests for all objects in the mini package."""
215
246
216
247
def test_doctest_examples (self , object_class ):
@@ -219,6 +250,6 @@ def test_doctest_examples(self, object_class):
219
250
220
251
doctest .run_docstring_examples (object_class , globals ())
221
252
222
- def certain_failure (self , object_class ):
253
+ def test_certain_failure (self , object_class ):
223
254
"""Fails for certain, for testing."""
224
255
assert False
0 commit comments