Skip to content

Commit fa8144e

Browse files
committed
test integration
1 parent c9f12db commit fa8144e

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

pytorch_forecasting/models/deepar/_deepar_metadata.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,8 @@ def get_model_cls(cls):
2323
return DeepAR
2424

2525
@classmethod
26-
def get_test_params(cls, parameter_set="default"):
27-
"""Return testing parameter settings for the skbase object.
28-
29-
``get_test_params`` is a unified interface point to store
30-
parameter settings for testing purposes. This function is also
31-
used in ``create_test_instance`` and ``create_test_instances_and_names``
32-
to construct test instances.
33-
34-
``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.
35-
36-
Each ``dict`` is a parameter configuration for testing,
37-
and can be used to construct an "interesting" test instance.
38-
A call to ``cls(**params)`` should
39-
be valid for all dictionaries ``params`` in the return of ``get_test_params``.
40-
41-
The ``get_test_params`` need not return fixed lists of dictionaries,
42-
it can also return dynamic or stochastic parameter settings.
43-
44-
Parameters
45-
----------
46-
parameter_set : str, default="default"
47-
Name of the set of test parameters to return, for use in tests. If no
48-
special parameters are defined for a value, will return `"default"` set.
26+
def get_test_train_params(cls):
27+
"""Return testing parameter settings for the trainer.
4928
5029
Returns
5130
-------

pytorch_forecasting/tests/test_all_estimators.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class BaseFixtureGenerator(_BaseFixtureGenerator):
7171
object_instance: instance of estimator inheriting from BaseObject
7272
ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
7373
instances are generated by create_test_instance class method of object_class
74+
trainer_kwargs: list of dict
75+
ranges over dictionaries of kwargs for the trainer
7476
"""
7577

7678
# overrides object retrieval in scikit-base
@@ -111,6 +113,7 @@ def _all_objects(self):
111113
"object_metadata",
112114
"object_class",
113115
"object_instance",
116+
"trainer_kwargs",
114117
]
115118

116119
def _generate_object_metadata(self, test_name, **kwargs):
@@ -145,6 +148,30 @@ def _generate_object_class(self, test_name, **kwargs):
145148

146149
return object_classes_to_test, object_names
147150

151+
def _generate_trainer_kwargs(self, test_name, **kwargs):
152+
"""Return kwargs for the trainer.
153+
154+
Fixtures parametrized
155+
---------------------
156+
trainer_kwargs: dict
157+
ranges over all kwargs for the trainer
158+
"""
159+
# call _generate_object_class to get all the classes
160+
object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name)
161+
162+
# create instances from the classes
163+
train_kwargs_to_test = []
164+
train_kwargs_names = []
165+
# retrieve all object parameters if multiple, construct instances
166+
for est in object_meta_to_test:
167+
est_name = est.__name__
168+
all_train_kwargs = est.get_test_train_params()
169+
train_kwargs_to_test += all_train_kwargs
170+
rg = range(len(all_train_kwargs))
171+
train_kwargs_names += [f"{est_name}_{i}" for i in rg]
172+
173+
return train_kwargs_to_test, train_kwargs_names
174+
148175

149176
def _integration(
150177
estimator_cls,
@@ -250,6 +277,16 @@ def test_doctest_examples(self, object_class):
250277

251278
doctest.run_docstring_examples(object_class, globals())
252279

253-
def test_certain_failure(self, object_class):
280+
def test_integration(
281+
self, object_class, trainer_kwargs, data_with_covariates, tmp_path
282+
):
254283
"""Fails for certain, for testing."""
255-
assert False
284+
from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss
285+
286+
if "loss" in trainer_kwargs and isinstance(
287+
trainer_kwargs["loss"], NegativeBinomialDistributionLoss
288+
):
289+
data_with_covariates = data_with_covariates.assign(
290+
volume=lambda x: x.volume.round()
291+
)
292+
_integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs)

0 commit comments

Comments
 (0)