@@ -71,6 +71,8 @@ class BaseFixtureGenerator(_BaseFixtureGenerator):
71
71
object_instance: instance of estimator inheriting from BaseObject
72
72
ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
73
73
instances are generated by create_test_instance class method of object_class
74
+ trainer_kwargs: list of dict
75
+ ranges over dictionaries of kwargs for the trainer
74
76
"""
75
77
76
78
# overrides object retrieval in scikit-base
@@ -111,6 +113,7 @@ def _all_objects(self):
111
113
"object_metadata" ,
112
114
"object_class" ,
113
115
"object_instance" ,
116
+ "trainer_kwargs" ,
114
117
]
115
118
116
119
def _generate_object_metadata (self , test_name , ** kwargs ):
@@ -145,6 +148,30 @@ def _generate_object_class(self, test_name, **kwargs):
145
148
146
149
return object_classes_to_test , object_names
147
150
151
+ def _generate_trainer_kwargs (self , test_name , ** kwargs ):
152
+ """Return kwargs for the trainer.
153
+
154
+ Fixtures parametrized
155
+ ---------------------
156
+ trainer_kwargs: dict
157
+ ranges over all kwargs for the trainer
158
+ """
159
+ # call _generate_object_class to get all the classes
160
+ object_meta_to_test , _ = self ._generate_object_metadata (test_name = test_name )
161
+
162
+ # create instances from the classes
163
+ train_kwargs_to_test = []
164
+ train_kwargs_names = []
165
+ # retrieve all object parameters if multiple, construct instances
166
+ for est in object_meta_to_test :
167
+ est_name = est .__name__
168
+ all_train_kwargs = est .get_test_train_params ()
169
+ train_kwargs_to_test += all_train_kwargs
170
+ rg = range (len (all_train_kwargs ))
171
+ train_kwargs_names += [f"{ est_name } _{ i } " for i in rg ]
172
+
173
+ return train_kwargs_to_test , train_kwargs_names
174
+
148
175
149
176
def _integration (
150
177
estimator_cls ,
@@ -250,6 +277,16 @@ def test_doctest_examples(self, object_class):
250
277
251
278
doctest .run_docstring_examples (object_class , globals ())
252
279
253
- def test_certain_failure (self , object_class ):
280
+ def test_integration (
281
+ self , object_class , trainer_kwargs , data_with_covariates , tmp_path
282
+ ):
254
283
"""Fails for certain, for testing."""
255
- assert False
284
+ from pytorch_forecasting .metrics import NegativeBinomialDistributionLoss
285
+
286
+ if "loss" in trainer_kwargs and isinstance (
287
+ trainer_kwargs ["loss" ], NegativeBinomialDistributionLoss
288
+ ):
289
+ data_with_covariates = data_with_covariates .assign (
290
+ volume = lambda x : x .volume .round ()
291
+ )
292
+ _integration (object_class , data_with_covariates , tmp_path , ** trainer_kwargs )
0 commit comments