[MNT] isolate pytorch-optimizer as soft dependency (#1641)
Isolates `pytorch-optimizer` as a soft dependency, moved into the soft dependency set `all_extras`. #1616

The imports happen only in `BaseModel`, when resolving aliases for `optimizer`.

Isolation consists of two steps:
* resolution of the alias is replaced with a request to install the package, and `pytorch_optimizer` optimizers are removed from the resolution scope if the package is not installed
* the default optimizer `"ranger"` is replaced with `"adam"` if `pytorch-optimizer` is not installed (it is left at `"ranger"` if installed)

Deprecation messages and actions are added to change the default to `"adam"` from version 1.2.0, in order to minimize the number of dependencies in default parameter settings; the fallback logic is sketched below.
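
A minimal sketch of that fallback, assuming only the standard library (`importlib.util.find_spec` stands in here for the internal `_get_installed_packages` helper used in the diff below; the warning text is abbreviated):

    from importlib.util import find_spec
    import warnings

    def _resolve_default_optimizer() -> str:
        # prefer "ranger" only when pytorch_optimizer is importable;
        # otherwise fall back to the torch-native "adam"
        if find_spec("pytorch_optimizer") is not None:
            warnings.warn(
                "From pytorch-forecasting 1.2.0 the default optimizer will be 'adam'; "
                "pass optimizer='ranger' explicitly to keep the current default.",
                stacklevel=2,
            )
            return "ranger"
        return "adam"
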
fkiraly authored Sep 4, 2024
1 parent 95fa06c commit bb6c8a2
Showing 2 changed files with 58 additions and 15 deletions.
pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -61,7 +61,6 @@ dependencies = [
"scipy >=1.8,<2.0",
"pandas >=1.3.0,<3.0.0",
"scikit-learn >=1.2,<2.0",
"pytorch-optimizer >=2.5.1,<4.0.0",
]

[project.optional-dependencies]
@@ -86,6 +85,7 @@ all_extras = [
"matplotlib",
"optuna >=3.1.0,<4.0.0",
"optuna-integration",
"pytorch_optimizer >=2.5.1,<4.0.0",
"statsmodels",
]
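
For users, the practical effect of this pyproject change is that a plain `pip install pytorch-forecasting` no longer pulls in `pytorch-optimizer`; to keep the `"ranger"` alias (and other `pytorch_optimizer` aliases) resolvable, install it explicitly, e.g. `pip install pytorch_optimizer`, or via the extras set, e.g. `pip install "pytorch-forecasting[all_extras]"`.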

pytorch_forecasting/models/base_model.py: 71 changes (57 additions, 14 deletions)
@@ -18,8 +18,6 @@
import numpy as np
from numpy import iterable
import pandas as pd
import pytorch_optimizer
from pytorch_optimizer import Ranger21
import scipy.stats
import torch
import torch.nn as nn
@@ -54,7 +52,7 @@
groupby_apply,
to_list,
)
from pytorch_forecasting.utils._dependencies import _check_matplotlib
from pytorch_forecasting.utils._dependencies import _check_matplotlib, _get_installed_packages

# todo: compile models

@@ -411,7 +409,7 @@ def __init__(
optimizer_params: Dict[str, Any] = None,
monotone_constaints: Dict[str, int] = {},
output_transformer: Callable = None,
optimizer="Ranger",
optimizer=None,
):
"""
BaseModel for timeseries forecasting from which to inherit from
@@ -445,12 +443,44 @@ def __init__(
or ``pytorch_optimizer``.
Alternatively, a class or function can be passed which takes parameters as first argument and
a `lr` argument (optionally also `weight_decay`). Defaults to
`"ranger" <https://pytorch-optimizers.readthedocs.io/en/latest/optimizer_api.html#ranger21>`_.
`"ranger" <https://pytorch-optimizers.readthedocs.io/en/latest/optimizer_api.html#ranger21>`_,
if pytorch_optimizer is installed, otherwise "adam".
"""
super().__init__()
# update hparams
frame = inspect.currentframe()
init_args = get_init_args(frame)

# TODO 1.2.0: remove warnings and change default optimizer to "adam"
if init_args["optimizer"] is None:
ptopt_in_env = "pytorch_optimizer" in _get_installed_packages()
if ptopt_in_env:
init_args["optimizer"] = "ranger"
warnings.warn(
"In pytorch-forecasting models, from version 1.2.0, "
"the default optimizer will be 'adam', in order to "
"minimize the number of dependencies in default parameter settings. "
"Users who wish to ensure their code continues using 'ranger' as optimizer "
"should ensure that pytorch_optimizer is installed, and set the optimizer "
"parameter explicitly to 'ranger'.",
stacklevel=2,
)
else:
init_args["optimizer"] = "adam"
warnings.warn(
"In pytorch-forecasting models, on versions 1.1.X, "
"the default optimizer defaults to 'adam', "
"if pytorch_optimizer is not installed, "
"otherwise it defaults to 'ranger' from pytorch_optimizer. "
"From version 1.2.0, the default optimizer will be 'adam' "
"regardless of whether pytorch_optimizer is installed, in order to "
"minimize the number of dependencies in default parameter settings. "
"Users who wish to ensure their code continues using 'ranger' as optimizer "
"should ensure that pytorch_optimizer is installed, and set the optimizer "
"parameter explicitly to 'ranger'.",
stacklevel=2,
)

self.save_hyperparameters(
{name: val for name, val in init_args.items() if name not in self.hparams and name not in ["self"]}
)
@@ -1150,6 +1180,7 @@ def configure_optimizers(self):
Returns:
Tuple[List]: first entry is list of optimizers and second is list of schedulers
"""
ptopt_in_env = "pytorch_optimizer" in _get_installed_packages()
# either set a schedule of lrs or find it dynamically
if self.hparams.optimizer_params is None:
optimizer_params = {}
@@ -1177,6 +1208,13 @@
self.parameters(), lr=lr, weight_decay=self.hparams.weight_decay, **optimizer_params
)
elif self.hparams.optimizer == "ranger":
if not ptopt_in_env:
raise ImportError(
"optimizer 'ranger' requires pytorch_optimizer in the evironment. "
"Please install pytorch_optimizer with `pip install pytorch_optimizer`."
)
from pytorch_optimizer import Ranger21

if any([isinstance(c, LearningRateFinder) for c in self.trainer.callbacks]):
# if finding learning rate, switch off warm up and cool down
optimizer_params.setdefault("num_warm_up_iterations", 0)
@@ -1203,15 +1241,20 @@
)
except TypeError: # in case there is no weight decay
optimizer = getattr(torch.optim, self.hparams.optimizer)(self.parameters(), lr=lr, **optimizer_params)
elif hasattr(pytorch_optimizer, self.hparams.optimizer):
try:
optimizer = getattr(pytorch_optimizer, self.hparams.optimizer)(
self.parameters(), lr=lr, weight_decay=self.hparams.weight_decay, **optimizer_params
)
except TypeError: # in case there is no weight decay
optimizer = getattr(pytorch_optimizer, self.hparams.optimizer)(
self.parameters(), lr=lr, **optimizer_params
)
elif ptopt_in_env:
import pytorch_optimizer

if hasattr(pytorch_optimizer, self.hparams.optimizer):
try:
optimizer = getattr(pytorch_optimizer, self.hparams.optimizer)(
self.parameters(), lr=lr, weight_decay=self.hparams.weight_decay, **optimizer_params
)
except TypeError: # in case there is no weight decay
optimizer = getattr(pytorch_optimizer, self.hparams.optimizer)(
self.parameters(), lr=lr, **optimizer_params
)
else:
raise ValueError(f"Optimizer of self.hparams.optimizer={self.hparams.optimizer} unknown")
else:
raise ValueError(f"Optimizer of self.hparams.optimizer={self.hparams.optimizer} unknown")
