@@ -71,6 +71,8 @@ class BaseFixtureGenerator(_BaseFixtureGenerator):
7171 object_instance: instance of estimator inheriting from BaseObject
7272 ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
7373 instances are generated by create_test_instance class method of object_class
74+ trainer_kwargs: list of dict
75+ ranges over dictionaries of kwargs for the trainer
7476 """
7577
7678 # overrides object retrieval in scikit-base
@@ -111,6 +113,7 @@ def _all_objects(self):
111113 "object_metadata" ,
112114 "object_class" ,
113115 "object_instance" ,
116+ "trainer_kwargs" ,
114117 ]
115118
116119 def _generate_object_metadata (self , test_name , ** kwargs ):
@@ -145,6 +148,30 @@ def _generate_object_class(self, test_name, **kwargs):
145148
146149 return object_classes_to_test , object_names
147150
151+ def _generate_trainer_kwargs (self , test_name , ** kwargs ):
152+ """Return kwargs for the trainer.
153+
154+ Fixtures parametrized
155+ ---------------------
156+ trainer_kwargs: dict
157+ ranges over all kwargs for the trainer
158+ """
159+ # call _generate_object_class to get all the classes
160+ object_meta_to_test , _ = self ._generate_object_metadata (test_name = test_name )
161+
162+ # create instances from the classes
163+ train_kwargs_to_test = []
164+ train_kwargs_names = []
165+ # retrieve all object parameters if multiple, construct instances
166+ for est in object_meta_to_test :
167+ est_name = est .__name__
168+ all_train_kwargs = est .get_test_train_params ()
169+ train_kwargs_to_test += all_train_kwargs
170+ rg = range (len (all_train_kwargs ))
171+ train_kwargs_names += [f"{ est_name } _{ i } " for i in rg ]
172+
173+ return train_kwargs_to_test , train_kwargs_names
174+
148175
149176def _integration (
150177 estimator_cls ,
@@ -250,6 +277,16 @@ def test_doctest_examples(self, object_class):
250277
251278 doctest .run_docstring_examples (object_class , globals ())
252279
253- def test_certain_failure (self , object_class ):
280+ def test_integration (
281+ self , object_class , trainer_kwargs , data_with_covariates , tmp_path
282+ ):
254283 """Fails for certain, for testing."""
255- assert False
284+ from pytorch_forecasting .metrics import NegativeBinomialDistributionLoss
285+
286+ if "loss" in trainer_kwargs and isinstance (
287+ trainer_kwargs ["loss" ], NegativeBinomialDistributionLoss
288+ ):
289+ data_with_covariates = data_with_covariates .assign (
290+ volume = lambda x : x .volume .round ()
291+ )
292+ _integration (object_class , data_with_covariates , tmp_path , ** trainer_kwargs )
0 commit comments