Skip to content

Commit bde7169

Browse files
authored
[MNT] add minimal dependency management utilities (#1628)
Adds minimal utilities for dependency management, to determine which packages are installed. This will later be used to isolate dependencies that could become soft dependencies. Instead of dumping the new utilities into the current `utils` file, a new `utils` folder is added: the current `utils` module is moved one level down into it, and a `_dependencies` submodule is added alongside.
1 parent 1a2af7d commit bde7169

File tree

10 files changed

+114
-7
lines changed

10 files changed

+114
-7
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ jobs:
103103
- name: Set up Python
104104
uses: actions/setup-python@v1
105105
with:
106-
python-version: 3.8
106+
python-version: 3.11
107107

108108
- name: Cache pip
109109
uses: actions/cache@v2

.readthedocs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ version: 2
99
# reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx
1010
sphinx:
1111
configuration: docs/source/conf.py
12-
fail_on_warning: true
12+
# fail_on_warning: true
1313

1414
# Build documentation with MkDocs
1515
#mkdocs:
@@ -21,6 +21,6 @@ formats:
2121

2222
# Optionally set the version of Python and requirements required to build your docs
2323
python:
24-
version: 3.8
24+
version: 3.11
2525
install:
2626
- requirements: docs/requirements.txt

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ nbconvert >=6.3.0
1818
recommonmark >=0.7.1
1919
pytorch-optimizer >=2.5.1
2020
fastapi >0.80
21+
cpflows

docs/source/_templates/custom-module-template.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@
5353
{% endblock %}
5454

5555
{% block modules %}
56-
{% if modules %}
56+
{% if all_modules %}
5757
.. rubric:: Modules
5858

5959
.. autosummary::
6060
:toctree:
6161
:template: custom-module-template.rst
6262
:recursive:
63-
{% for item in modules %}
63+
{% for item in all_modules %}
6464
{{ item }}
6565
{%- endfor %}
6666
{% endif %}

docs/source/conf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,7 @@ def setup(app: Sphinx):
176176
intersphinx_mapping = {
177177
"sklearn": ("https://scikit-learn.org/stable/", None),
178178
}
179+
180+
suppress_warnings = [
181+
"autosummary.import_cycle",
182+
]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ dev = [
9898
"pytest-dotenv>=0.5.2,<1.0.0",
9999
"tensorboard>=2.12.1,<3.0.0",
100100
"pandoc>=2.3,<3.0.0",
101+
"cpflows",
101102
]
102103

103104
github-actions = ["pytest-github-actions-annotate-failures"]

pytorch_forecasting/metrics/_mqf2_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ class DeepConvexNet(DeepConvexFlow):
1212
r"""
1313
Class that takes a partially input convex neural network (picnn)
1414
as input and equips it with functions of logdet
15-
computation (both estimation and exact computation)
15+
computation (both estimation and exact computation).
1616
This class is based on DeepConvexFlow of the CP-Flow
1717
repo (https://github.com/CW-Huang/CP-Flow)
1818
For details of the logdet estimator, see
1919
``Convex potential flows: Universal probability distributions
2020
with optimal transport and convex optimization``
21+
2122
Parameters
2223
----------
2324
picnn
@@ -94,6 +95,7 @@ class SequentialNet(SequentialFlow):
9495
layers and provides energy score computation
9596
This class is based on SequentialFlow of the CP-Flow repo
9697
(https://github.com/CW-Huang/CP-Flow)
98+
9799
Parameters
98100
----------
99101
networks
@@ -116,6 +118,7 @@ def es_sample(self, hidden_state: torch.Tensor, dimension: int) -> torch.Tensor:
116118
"""
117119
Auxiliary function for energy score computation
118120
Drawing samples conditioned on the hidden state
121+
119122
Parameters
120123
----------
121124
hidden_state
@@ -159,6 +162,7 @@ def energy_score(
159162
h_i is the hidden state associated with z_i,
160163
and es_num_samples is the number of samples drawn
161164
for each of w, w', w'' in energy score approximation
165+
162166
Parameters
163167
----------
164168
z
@@ -224,6 +228,7 @@ class MQF2Distribution(Distribution):
224228
Distribution class for the model MQF2 proposed in the paper
225229
``Multivariate Quantile Function Forecaster``
226230
by Kan, Aubet, Januschowski, Park, Benidis, Ruthotto, Gasthaus
231+
227232
Parameters
228233
----------
229234
picnn
@@ -290,6 +295,7 @@ def stack_sliding_view(self, z: torch.Tensor) -> torch.Tensor:
290295
over the observations z
291296
Then, reshapes the observations into a 2-dimensional tensor for
292297
further computation
298+
293299
Parameters
294300
----------
295301
z
@@ -317,6 +323,7 @@ def log_prob(self, z: torch.Tensor) -> torch.Tensor:
317323
"""
318324
Computes the log likelihood log(g(z)) + logdet(dg(z)/dz),
319325
where g is the gradient of the picnn
326+
320327
Parameters
321328
----------
322329
z
@@ -346,6 +353,7 @@ def energy_score(self, z: torch.Tensor) -> torch.Tensor:
346353
h_i is the hidden state associated with z_i,
347354
and es_num_samples is the number of samples drawn
348355
for each of w, w', w'' in energy score approximation
356+
349357
Parameters
350358
----------
351359
z
@@ -370,14 +378,15 @@ def energy_score(self, z: torch.Tensor) -> torch.Tensor:
370378
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
371379
"""
372380
Generates the sample paths
381+
373382
Parameters
374383
----------
375384
sample_shape
376385
Shape of the samples
377386
Returns
378387
-------
379388
sample_paths
380-
Tesnor of shape (batch_size, *sample_shape, prediction_length)
389+
Tesnor of shape (batch_size, * sample_shape, prediction_length)
381390
"""
382391

383392
numel_batch = self.numel_batch
@@ -407,6 +416,7 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
407416
def quantile(self, alpha: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> torch.Tensor:
408417
"""
409418
Generates the predicted paths associated with the quantile levels alpha
419+
410420
Parameters
411421
----------
412422
alpha

pytorch_forecasting/utils/__init__.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
PyTorch Forecasting package for timeseries forecasting with PyTorch.
3+
"""
4+
5+
from pytorch_forecasting.utils._utils import (
6+
InitialParameterRepresenterMixIn,
7+
OutputMixIn,
8+
TupleOutputMixIn,
9+
apply_to_list,
10+
autocorrelation,
11+
concat_sequences,
12+
create_mask,
13+
detach,
14+
get_embedding_size,
15+
groupby_apply,
16+
integer_histogram,
17+
masked_op,
18+
move_to_device,
19+
padded_stack,
20+
profile,
21+
redirect_stdout,
22+
repr_class,
23+
to_list,
24+
unpack_sequence,
25+
unsqueeze_like,
26+
)
27+
28+
__all__ = [
29+
"InitialParameterRepresenterMixIn",
30+
"OutputMixIn",
31+
"TupleOutputMixIn",
32+
"apply_to_list",
33+
"autocorrelation",
34+
"get_embedding_size",
35+
"concat_sequences",
36+
"create_mask",
37+
"to_list",
38+
"RecurrentNetwork",
39+
"DecoderMLP",
40+
"detach",
41+
"masked_op",
42+
"move_to_device",
43+
"integer_histogram",
44+
"groupby_apply",
45+
"padded_stack",
46+
"profile",
47+
"redirect_stdout",
48+
"repr_class",
49+
"unpack_sequence",
50+
"unsqueeze_like",
51+
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Utilities for managing dependencies.
2+
3+
Copied from sktime/skbase.
4+
"""
5+
6+
from functools import lru_cache
7+
8+
9+
@lru_cache
10+
def _get_installed_packages_private():
11+
"""Get a dictionary of installed packages and their versions.
12+
13+
Same as _get_installed_packages, but internal to avoid mutating the lru_cache
14+
by accident.
15+
"""
16+
from importlib.metadata import distributions, version
17+
18+
dists = distributions()
19+
package_names = {dist.metadata["Name"] for dist in dists}
20+
package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names}
21+
# developer note:
22+
# we cannot just use distributions naively,
23+
# because the same top level package name may appear *twice*,
24+
# e.g., in a situation where a virtual env overrides a base env,
25+
# such as in deployment environments like databricks.
26+
# the "version" contract ensures we always get the version that corresponds
27+
# to the importable distribution, i.e., the top one in the sys.path.
28+
return package_versions
29+
30+
31+
def _get_installed_packages():
    """Get a dictionary of installed packages and their versions.

    Returns
    -------
    dict : dictionary of installed packages and their versions
        keys are PEP 440 compatible package names, values are package versions
        MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
    """
    # return a shallow copy so callers cannot mutate the cached dict
    return dict(_get_installed_packages_private())
File renamed without changes.

0 commit comments

Comments
 (0)