@@ -71,6 +71,8 @@ class BaseFixtureGenerator(_BaseFixtureGenerator):
7171    object_instance: instance of estimator inheriting from BaseObject 
7272        ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS 
7373        instances are generated by create_test_instance class method of object_class 
74+     trainer_kwargs: list of dict 
75+         ranges over dictionaries of kwargs for the trainer 
7476    """ 
7577
7678    # overrides object retrieval in scikit-base 
@@ -111,6 +113,7 @@ def _all_objects(self):
111113        "object_metadata" ,
112114        "object_class" ,
113115        "object_instance" ,
116+         "trainer_kwargs" ,
114117    ]
115118
116119    def  _generate_object_metadata (self , test_name , ** kwargs ):
@@ -145,6 +148,30 @@ def _generate_object_class(self, test_name, **kwargs):
145148
146149        return  object_classes_to_test , object_names 
147150
151+     def  _generate_trainer_kwargs (self , test_name , ** kwargs ):
152+         """Return kwargs for the trainer. 
153+ 
154+         Fixtures parametrized 
155+         --------------------- 
156+         trainer_kwargs: dict 
157+             ranges over all kwargs for the trainer 
158+         """ 
159+         # call _generate_object_class to get all the classes 
160+         object_meta_to_test , _  =  self ._generate_object_metadata (test_name = test_name )
161+ 
162+         # create instances from the classes 
163+         train_kwargs_to_test  =  []
164+         train_kwargs_names  =  []
165+         # retrieve all object parameters if multiple, construct instances 
166+         for  est  in  object_meta_to_test :
167+             est_name  =  est .__name__ 
168+             all_train_kwargs  =  est .get_test_train_params ()
169+             train_kwargs_to_test  +=  all_train_kwargs 
170+             rg  =  range (len (all_train_kwargs ))
171+             train_kwargs_names  +=  [f"{ est_name }  _{ i }  "  for  i  in  rg ]
172+ 
173+         return  train_kwargs_to_test , train_kwargs_names 
174+ 
148175
149176def  _integration (
150177    estimator_cls ,
@@ -250,6 +277,16 @@ def test_doctest_examples(self, object_class):
250277
251278        doctest .run_docstring_examples (object_class , globals ())
252279
253-     def  test_certain_failure (self , object_class ):
280+     def  test_integration (
281+         self , object_class , trainer_kwargs , data_with_covariates , tmp_path 
282+     ):
254283        """Fails for certain, for testing.""" 
255-         assert  False 
284+         from  pytorch_forecasting .metrics  import  NegativeBinomialDistributionLoss 
285+ 
286+         if  "loss"  in  trainer_kwargs  and  isinstance (
287+             trainer_kwargs ["loss" ], NegativeBinomialDistributionLoss 
288+         ):
289+             data_with_covariates  =  data_with_covariates .assign (
290+                 volume = lambda  x : x .volume .round ()
291+             )
292+         _integration (object_class , data_with_covariates , tmp_path , ** trainer_kwargs )
0 commit comments