@@ -127,7 +127,7 @@ def _generate_object_metadata(self, test_name, **kwargs):
        object_classes_to_test = [
            est for est in self._all_objects() if not self.is_excluded(test_name, est)
        ]
-       object_names = [est.__name__ for est in object_classes_to_test]
+       object_names = [est.name() for est in object_classes_to_test]

        return object_classes_to_test, object_names
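# Illustrative sketch, not part of the diff above: the objects collected by
# _generate_object_metadata are now metadata records rather than raw model
# classes, so the display name comes from a name() classmethod instead of
# __name__. ToyMetadata and its name() body are assumptions for illustration only.
class ToyMetadata:
    """Stand-in for a metadata record describing one forecaster."""

    @classmethod
    def name(cls):
        # assumed behaviour: return a human-readable identifier for the model
        return "ToyForecaster"


object_classes_to_test = [ToyMetadata]
object_names = [est.name() for est in object_classes_to_test]  # ["ToyForecaster"]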
@@ -156,21 +156,16 @@ def _generate_trainer_kwargs(self, test_name, **kwargs):
        trainer_kwargs: dict
            ranges over all kwargs for the trainer
        """
-       # call _generate_object_class to get all the classes
-       object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name)
+       if "object_metadata" in kwargs.keys():
+           obj_meta = kwargs["object_metadata"]
+       else:
+           return []

-       # create instances from the classes
-       train_kwargs_to_test = []
-       train_kwargs_names = []
-       # retrieve all object parameters if multiple, construct instances
-       for est in object_meta_to_test:
-           est_name = est.__name__
-           all_train_kwargs = est.get_test_train_params()
-           train_kwargs_to_test += all_train_kwargs
-           rg = range(len(all_train_kwargs))
-           train_kwargs_names += [f"{est_name}_{i}" for i in rg]
+       all_train_kwargs = obj_meta.get_test_train_params()
+       rg = range(len(all_train_kwargs))
+       train_kwargs_names = [str(i) for i in rg]

-       return train_kwargs_to_test, train_kwargs_names
+       return all_train_kwargs, train_kwargs_names


    def _integration(
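# Illustrative sketch, not part of the diff above: _generate_trainer_kwargs now
# expects a single metadata record via the object_metadata keyword and returns
# that record's training-parameter sets, instead of looping over every class
# itself. The standalone function and ToyMetadata below are assumptions.
def generate_trainer_kwargs(**kwargs):
    if "object_metadata" in kwargs:
        obj_meta = kwargs["object_metadata"]
    else:
        return []
    all_train_kwargs = obj_meta.get_test_train_params()
    train_kwargs_names = [str(i) for i in range(len(all_train_kwargs))]
    return all_train_kwargs, train_kwargs_names


class ToyMetadata:
    @classmethod
    def get_test_train_params(cls):
        # assumed shape: one dict of keyword arguments per training configuration
        return [{"lr": 0.01}, {"lr": 0.1, "dropout": 0.2}]


kwargs_list, kwargs_names = generate_trainer_kwargs(object_metadata=ToyMetadata)
# kwargs_list == [{"lr": 0.01}, {"lr": 0.1, "dropout": 0.2}]; kwargs_names == ["0", "1"]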
@@ -278,11 +273,17 @@ def test_doctest_examples(self, object_class):
        doctest.run_docstring_examples(object_class, globals())

    def test_integration(
-       self, object_class, trainer_kwargs, data_with_covariates, tmp_path
+       self,
+       object_metadata,
+       trainer_kwargs,
+       data_with_covariates,
+       tmp_path,
    ):
        """Fails for certain, for testing."""
        from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss

+       object_class = object_metadata.get_model_cls()
+
        if "loss" in trainer_kwargs and isinstance(
            trainer_kwargs["loss"], NegativeBinomialDistributionLoss
        ):
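# Illustrative sketch, not part of the diff above: with the fixture now yielding
# a metadata record, test_integration resolves the concrete model class through
# get_model_cls() before using it. ToyModel, ToyMetadata, and the hidden_size
# parameter are assumptions for illustration only.
class ToyModel:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class ToyMetadata:
    @classmethod
    def get_model_cls(cls):
        return ToyModel


object_metadata = ToyMetadata
object_class = object_metadata.get_model_cls()  # resolves to ToyModel
model = object_class(hidden_size=8)  # instantiate as the integration test would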