Skip to content

Commit

Permalink
test integration
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Feb 23, 2025
1 parent c9f12db commit fa8144e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
25 changes: 2 additions & 23 deletions pytorch_forecasting/models/deepar/_deepar_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,8 @@ def get_model_cls(cls):
return DeepAR

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the skbase object.
``get_test_params`` is a unified interface point to store
parameter settings for testing purposes. This function is also
used in ``create_test_instance`` and ``create_test_instances_and_names``
to construct test instances.
``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.
Each ``dict`` is a parameter configuration for testing,
and can be used to construct an "interesting" test instance.
A call to ``cls(**params)`` should
be valid for all dictionaries ``params`` in the return of ``get_test_params``.
The ``get_test_params`` need not return fixed lists of dictionaries,
it can also return dynamic or stochastic parameter settings.
Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
def get_test_train_params(cls):
"""Return testing parameter settings for the trainer.
Returns
-------
Expand Down
41 changes: 39 additions & 2 deletions pytorch_forecasting/tests/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class BaseFixtureGenerator(_BaseFixtureGenerator):
object_instance: instance of estimator inheriting from BaseObject
ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
instances are generated by create_test_instance class method of object_class
trainer_kwargs: list of dict
ranges over dictionaries of kwargs for the trainer
"""

# overrides object retrieval in scikit-base
Expand Down Expand Up @@ -111,6 +113,7 @@ def _all_objects(self):
"object_metadata",
"object_class",
"object_instance",
"trainer_kwargs",
]

def _generate_object_metadata(self, test_name, **kwargs):
Expand Down Expand Up @@ -145,6 +148,30 @@ def _generate_object_class(self, test_name, **kwargs):

return object_classes_to_test, object_names

def _generate_trainer_kwargs(self, test_name, **kwargs):
"""Return kwargs for the trainer.
Fixtures parametrized
---------------------
trainer_kwargs: dict
ranges over all kwargs for the trainer
"""
# call _generate_object_class to get all the classes
object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name)

# create instances from the classes
train_kwargs_to_test = []
train_kwargs_names = []
# retrieve all object parameters if multiple, construct instances
for est in object_meta_to_test:
est_name = est.__name__
all_train_kwargs = est.get_test_train_params()
train_kwargs_to_test += all_train_kwargs
rg = range(len(all_train_kwargs))
train_kwargs_names += [f"{est_name}_{i}" for i in rg]

return train_kwargs_to_test, train_kwargs_names


def _integration(
estimator_cls,
Expand Down Expand Up @@ -250,6 +277,16 @@ def test_doctest_examples(self, object_class):

doctest.run_docstring_examples(object_class, globals())

def test_certain_failure(self, object_class):
def test_integration(
self, object_class, trainer_kwargs, data_with_covariates, tmp_path
):
"""Fails for certain, for testing."""
assert False
from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss

if "loss" in trainer_kwargs and isinstance(
trainer_kwargs["loss"], NegativeBinomialDistributionLoss
):
data_with_covariates = data_with_covariates.assign(
volume=lambda x: x.volume.round()
)
_integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs)

0 comments on commit fa8144e

Please sign in to comment.