@@ -127,7 +127,7 @@ def _generate_object_metadata(self, test_name, **kwargs):
127127 object_classes_to_test = [
128128 est for est in self ._all_objects () if not self .is_excluded (test_name , est )
129129 ]
130- object_names = [est .__name__ for est in object_classes_to_test ]
130+ object_names = [est .name () for est in object_classes_to_test ]
131131
132132 return object_classes_to_test , object_names
133133
@@ -156,21 +156,16 @@ def _generate_trainer_kwargs(self, test_name, **kwargs):
156156 trainer_kwargs: dict
157157 ranges over all kwargs for the trainer
158158 """
159- # call _generate_object_class to get all the classes
160- object_meta_to_test , _ = self ._generate_object_metadata (test_name = test_name )
159+ if "object_metadata" in kwargs .keys ():
160+ obj_meta = kwargs ["object_metadata" ]
161+ else :
162+ return []
161163
162- # create instances from the classes
163- train_kwargs_to_test = []
164- train_kwargs_names = []
165- # retrieve all object parameters if multiple, construct instances
166- for est in object_meta_to_test :
167- est_name = est .__name__
168- all_train_kwargs = est .get_test_train_params ()
169- train_kwargs_to_test += all_train_kwargs
170- rg = range (len (all_train_kwargs ))
171- train_kwargs_names += [f"{ est_name } _{ i } " for i in rg ]
164+ all_train_kwargs = obj_meta .get_test_train_params ()
165+ rg = range (len (all_train_kwargs ))
166+ train_kwargs_names = [str (i ) for i in rg ]
172167
173- return train_kwargs_to_test , train_kwargs_names
168+ return all_train_kwargs , train_kwargs_names
174169
175170
176171def _integration (
@@ -278,11 +273,17 @@ def test_doctest_examples(self, object_class):
278273 doctest .run_docstring_examples (object_class , globals ())
279274
280275 def test_integration (
281- self , object_class , trainer_kwargs , data_with_covariates , tmp_path
276+ self ,
277+ object_metadata ,
278+ trainer_kwargs ,
279+ data_with_covariates ,
280+ tmp_path ,
282281 ):
283282 """Fails for certain, for testing."""
284283 from pytorch_forecasting .metrics import NegativeBinomialDistributionLoss
285284
285+ object_class = object_metadata .get_model_cls ()
286+
286287 if "loss" in trainer_kwargs and isinstance (
287288 trainer_kwargs ["loss" ], NegativeBinomialDistributionLoss
288289 ):
0 commit comments