From bb6c8a2243c35ca35c2c0e14093d352430fee6d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 4 Sep 2024 23:08:04 +0100 Subject: [PATCH] [MNT] isolate `pytorch-optimizer` as soft dependency (#1641) Isolates `pytorch-optimizer` as soft dependency in a new soft dep set `all_extras`. https://github.com/jdb78/pytorch-forecasting/issues/1616 The imports happen only in `BaseModel`, when resolving aliases for `optimizer`. Isolation consists of two steps: * replacing resolution of the alias with a request to install the package, and removing from resolution scope if not installed * replacing the default optimizer `"ranger"` with `"adam"` if `pytorch-optimizer` is not installed (left at `"ranger"` if installed) Deprecation messages and actions are added, to change the default to `"adam"` from 1.2.0, in order to minimize the number of dependencies in default parameter settings --- pyproject.toml | 2 +- pytorch_forecasting/models/base_model.py | 71 +++++++++++++++++++----- 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 06cad359..34d904b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,6 @@ dependencies = [ "scipy >=1.8,<2.0", "pandas >=1.3.0,<3.0.0", "scikit-learn >=1.2,<2.0", - "pytorch-optimizer >=2.5.1,<4.0.0", ] [project.optional-dependencies] @@ -86,6 +85,7 @@ all_extras = [ "matplotlib", "optuna >=3.1.0,<4.0.0", "optuna-integration", + "pytorch_optimizer >=2.5.1,<4.0.0", "statsmodels", ] diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 75b78078..a82196cf 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -18,8 +18,6 @@ import numpy as np from numpy import iterable import pandas as pd -import pytorch_optimizer -from pytorch_optimizer import Ranger21 import scipy.stats import torch import torch.nn as nn @@ -54,7 +52,7 @@ groupby_apply, to_list, ) -from 
pytorch_forecasting.utils._dependencies import _check_matplotlib +from pytorch_forecasting.utils._dependencies import _check_matplotlib, _get_installed_packages # todo: compile models @@ -411,7 +409,7 @@ def __init__( optimizer_params: Dict[str, Any] = None, monotone_constaints: Dict[str, int] = {}, output_transformer: Callable = None, - optimizer="Ranger", + optimizer=None, ): """ BaseModel for timeseries forecasting from which to inherit from @@ -445,12 +443,44 @@ def __init__( or ``pytorch_optimizer``. Alternatively, a class or function can be passed which takes parameters as first argument and a `lr` argument (optionally also `weight_decay`). Defaults to - `"ranger" `_. + `"ranger" `_, + if pytorch_optimizer is installed, otherwise "adam". """ super().__init__() # update hparams frame = inspect.currentframe() init_args = get_init_args(frame) + + # TODO 1.2.0: remove warnings and change default optimizer to "adam" + if init_args["optimizer"] is None: + ptopt_in_env = "pytorch_optimizer" in _get_installed_packages() + if ptopt_in_env: + init_args["optimizer"] = "ranger" + warnings.warn( + "In pytorch-forecasting models, from version 1.2.0, " + "the default optimizer will be 'adam', in order to " + "minimize the number of dependencies in default parameter settings. " + "Users who wish to ensure their code continues using 'ranger' as optimizer " + "should ensure that pytorch_optimizer is installed, and set the optimizer " + "parameter explicitly to 'ranger'.", + stacklevel=2, + ) + else: + init_args["optimizer"] = "adam" + warnings.warn( + "In pytorch-forecasting models, on versions 1.1.X, " + "the default optimizer defaults to 'adam', " + "if pytorch_optimizer is not installed, " + "otherwise it defaults to 'ranger' from pytorch_optimizer. " + "From version 1.2.0, the default optimizer will be 'adam' " + "regardless of whether pytorch_optimizer is installed, in order to " + "minimize the number of dependencies in default parameter settings. 
" + "Users who wish to ensure their code continues using 'ranger' as optimizer " + "should ensure that pytorch_optimizer is installed, and set the optimizer " + "parameter explicitly to 'ranger'.", + stacklevel=2, + ) + self.save_hyperparameters( {name: val for name, val in init_args.items() if name not in self.hparams and name not in ["self"]} ) @@ -1150,6 +1180,7 @@ def configure_optimizers(self): Returns: Tuple[List]: first entry is list of optimizers and second is list of schedulers """ + ptopt_in_env = "pytorch_optimizer" in _get_installed_packages() # either set a schedule of lrs or find it dynamically if self.hparams.optimizer_params is None: optimizer_params = {} @@ -1177,6 +1208,13 @@ def configure_optimizers(self): self.parameters(), lr=lr, weight_decay=self.hparams.weight_decay, **optimizer_params ) elif self.hparams.optimizer == "ranger": + if not ptopt_in_env: + raise ImportError( + "optimizer 'ranger' requires pytorch_optimizer in the evironment. " + "Please install pytorch_optimizer with `pip install pytorch_optimizer`." 
+ ) + from pytorch_optimizer import Ranger21 + if any([isinstance(c, LearningRateFinder) for c in self.trainer.callbacks]): # if finding learning rate, switch off warm up and cool down optimizer_params.setdefault("num_warm_up_iterations", 0) @@ -1203,15 +1241,20 @@ def configure_optimizers(self): ) except TypeError: # in case there is no weight decay optimizer = getattr(torch.optim, self.hparams.optimizer)(self.parameters(), lr=lr, **optimizer_params) - elif hasattr(pytorch_optimizer, self.hparams.optimizer): - try: - optimizer = getattr(pytorch_optimizer, self.hparams.optimizer)( - self.parameters(), lr=lr, weight_decay=self.hparams.weight_decay, **optimizer_params - ) - except TypeError: # in case there is no weight decay - optimizer = getattr(pytorch_optimizer, self.hparams.optimizer)( - self.parameters(), lr=lr, **optimizer_params - ) + elif ptopt_in_env: + import pytorch_optimizer + + if hasattr(pytorch_optimizer, self.hparams.optimizer): + try: + optimizer = getattr(pytorch_optimizer, self.hparams.optimizer)( + self.parameters(), lr=lr, weight_decay=self.hparams.weight_decay, **optimizer_params + ) + except TypeError: # in case there is no weight decay + optimizer = getattr(pytorch_optimizer, self.hparams.optimizer)( + self.parameters(), lr=lr, **optimizer_params + ) + else: + raise ValueError(f"Optimizer of self.hparams.optimizer={self.hparams.optimizer} unknown") else: raise ValueError(f"Optimizer of self.hparams.optimizer={self.hparams.optimizer} unknown")