Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Feb 23, 2025
1 parent fa8144e commit 232a510
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
8 changes: 8 additions & 0 deletions pytorch_forecasting/models/base/_base_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def get_model_cls(cls):
"""Get model class."""
raise NotImplementedError

@classmethod
def name(cls):
"""Get model name."""
name = cls.get_class_tags().get("info:name", None)
if name is None:
name = cls.get_model_cls().__name__
return name

@classmethod
def create_test_instance(cls, parameter_set="default"):
"""Construct an instance of the class, using first test parameter set.
Expand Down
4 changes: 3 additions & 1 deletion pytorch_forecasting/models/deepar/_deepar_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ class DeepARMetadata(_BasePtForecaster):
"""DeepAR metadata container."""

_tags = {
"info:name": "DeepAR",
"info:compute": 3,
"authors": ["jdb78"],
"capability:exogenous": True,
"capability:multivariate": True,
"capability:pred_int": True,
"capability:flexible_history_length": True,
"capability:cold_start": False,
"info:compute": 3,
}

@classmethod
Expand Down
31 changes: 16 additions & 15 deletions pytorch_forecasting/tests/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _generate_object_metadata(self, test_name, **kwargs):
object_classes_to_test = [
est for est in self._all_objects() if not self.is_excluded(test_name, est)
]
object_names = [est.__name__ for est in object_classes_to_test]
object_names = [est.name() for est in object_classes_to_test]

return object_classes_to_test, object_names

Expand Down Expand Up @@ -156,21 +156,16 @@ def _generate_trainer_kwargs(self, test_name, **kwargs):
trainer_kwargs: dict
ranges over all kwargs for the trainer
"""
# call _generate_object_class to get all the classes
object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name)
if "object_metadata" in kwargs.keys():
obj_meta = kwargs["object_metadata"]
else:
return []

# create instances from the classes
train_kwargs_to_test = []
train_kwargs_names = []
# retrieve all object parameters if multiple, construct instances
for est in object_meta_to_test:
est_name = est.__name__
all_train_kwargs = est.get_test_train_params()
train_kwargs_to_test += all_train_kwargs
rg = range(len(all_train_kwargs))
train_kwargs_names += [f"{est_name}_{i}" for i in rg]
all_train_kwargs = obj_meta.get_test_train_params()
rg = range(len(all_train_kwargs))
train_kwargs_names = [str(i) for i in rg]

return train_kwargs_to_test, train_kwargs_names
return all_train_kwargs, train_kwargs_names


def _integration(
Expand Down Expand Up @@ -278,11 +273,17 @@ def test_doctest_examples(self, object_class):
doctest.run_docstring_examples(object_class, globals())

def test_integration(
self, object_class, trainer_kwargs, data_with_covariates, tmp_path
self,
object_metadata,
trainer_kwargs,
data_with_covariates,
tmp_path,
):
"""Fails for certain, for testing."""
from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss

object_class = object_metadata.get_model_cls()

if "loss" in trainer_kwargs and isinstance(
trainer_kwargs["loss"], NegativeBinomialDistributionLoss
):
Expand Down

0 comments on commit 232a510

Please sign in to comment.