Skip to content

Commit 232a510

Browse files
committed
fixes
1 parent fa8144e commit 232a510

File tree

3 files changed

+27
-16
lines changed

3 files changed

+27
-16
lines changed

pytorch_forecasting/models/base/_base_object.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ def get_model_cls(cls):
2727
"""Get model class."""
2828
raise NotImplementedError
2929

30+
@classmethod
31+
def name(cls):
32+
"""Get model name."""
33+
name = cls.get_class_tags().get("info:name", None)
34+
if name is None:
35+
name = cls.get_model_cls().__name__
36+
return name
37+
3038
@classmethod
3139
def create_test_instance(cls, parameter_set="default"):
3240
"""Construct an instance of the class, using first test parameter set.

pytorch_forecasting/models/deepar/_deepar_metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ class DeepARMetadata(_BasePtForecaster):
77
"""DeepAR metadata container."""
88

99
_tags = {
10+
"info:name": "DeepAR",
11+
"info:compute": 3,
12+
"authors": ["jdb78"],
1013
"capability:exogenous": True,
1114
"capability:multivariate": True,
1215
"capability:pred_int": True,
1316
"capability:flexible_history_length": True,
1417
"capability:cold_start": False,
15-
"info:compute": 3,
1618
}
1719

1820
@classmethod

pytorch_forecasting/tests/test_all_estimators.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _generate_object_metadata(self, test_name, **kwargs):
127127
object_classes_to_test = [
128128
est for est in self._all_objects() if not self.is_excluded(test_name, est)
129129
]
130-
object_names = [est.__name__ for est in object_classes_to_test]
130+
object_names = [est.name() for est in object_classes_to_test]
131131

132132
return object_classes_to_test, object_names
133133

@@ -156,21 +156,16 @@ def _generate_trainer_kwargs(self, test_name, **kwargs):
156156
trainer_kwargs: dict
157157
ranges over all kwargs for the trainer
158158
"""
159-
# call _generate_object_class to get all the classes
160-
object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name)
159+
if "object_metadata" in kwargs.keys():
160+
obj_meta = kwargs["object_metadata"]
161+
else:
162+
return []
161163

162-
# create instances from the classes
163-
train_kwargs_to_test = []
164-
train_kwargs_names = []
165-
# retrieve all object parameters if multiple, construct instances
166-
for est in object_meta_to_test:
167-
est_name = est.__name__
168-
all_train_kwargs = est.get_test_train_params()
169-
train_kwargs_to_test += all_train_kwargs
170-
rg = range(len(all_train_kwargs))
171-
train_kwargs_names += [f"{est_name}_{i}" for i in rg]
164+
all_train_kwargs = obj_meta.get_test_train_params()
165+
rg = range(len(all_train_kwargs))
166+
train_kwargs_names = [str(i) for i in rg]
172167

173-
return train_kwargs_to_test, train_kwargs_names
168+
return all_train_kwargs, train_kwargs_names
174169

175170

176171
def _integration(
@@ -278,11 +273,17 @@ def test_doctest_examples(self, object_class):
278273
doctest.run_docstring_examples(object_class, globals())
279274

280275
def test_integration(
281-
self, object_class, trainer_kwargs, data_with_covariates, tmp_path
276+
self,
277+
object_metadata,
278+
trainer_kwargs,
279+
data_with_covariates,
280+
tmp_path,
282281
):
283282
"""Fails for certain, for testing."""
284283
from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss
285284

285+
object_class = object_metadata.get_model_cls()
286+
286287
if "loss" in trainer_kwargs and isinstance(
287288
trainer_kwargs["loss"], NegativeBinomialDistributionLoss
288289
):

0 commit comments

Comments
 (0)