From b8053fe7e8a0a5e19cad966dc9a24825de7feba8 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 28 Feb 2025 14:38:29 +0000
Subject: [PATCH 01/10] Update

[ghstack-poisoned]
---
 .../unittest/helpers/coverage_run_parallel.py |  4 +-
 .pre-commit-config.yaml                       |  6 ++
 build_tools/setup_helpers/extension.py        |  5 +-
 setup.py                                      |  8 +-
 sota-implementations/a2c/a2c_atari.py         |  2 +-
 sota-implementations/a2c/a2c_mujoco.py        |  2 +-
 sota-implementations/cql/cql_offline.py       |  5 +-
 sota-implementations/cql/cql_online.py        |  4 +-
 .../cql/discrete_cql_online.py                |  6 +-
 sota-implementations/crossq/crossq.py         |  6 +-
 sota-implementations/ddpg/ddpg.py             |  5 +-
 .../decision_transformer/dt.py                |  4 +-
 .../decision_transformer/online_dt.py         |  3 +-
 .../discrete_sac/discrete_sac.py              |  2 +-
 sota-implementations/dqn/dqn_atari.py         |  3 +-
 sota-implementations/dqn/dqn_cartpole.py      |  3 +-
 sota-implementations/dreamer/dreamer.py       |  4 +-
 sota-implementations/gail/gail.py             |  6 +-
 sota-implementations/gail/gail_utils.py       |  4 +-
 .../impala/impala_multi_node_ray.py           |  2 +-
 .../impala/impala_multi_node_submitit.py      |  2 +-
 .../impala/impala_single_node.py              |  2 +-
 sota-implementations/iql/discrete_iql.py      |  6 +-
 sota-implementations/iql/iql_offline.py       |  6 +-
 sota-implementations/iql/iql_online.py        |  6 +-
 sota-implementations/multiagent/iql.py        |  3 +-
 .../multiagent/maddpg_iddpg.py                |  3 +-
 sota-implementations/multiagent/mappo_ippo.py |  3 +-
 sota-implementations/multiagent/qmix_vdn.py   |  3 +-
 sota-implementations/multiagent/sac.py        |  3 +-
 sota-implementations/ppo/ppo_atari.py         |  3 +-
 sota-implementations/ppo/ppo_mujoco.py        |  3 +-
 sota-implementations/redq/redq.py             |  2 +-
 sota-implementations/redq/utils.py            | 38 ++++---
 sota-implementations/sac/sac.py               |  5 +-
 sota-implementations/td3/td3.py               |  6 +-
 sota-implementations/td3_bc/td3_bc.py         |  5 +-
 test/_utils_internal.py                       | 24 ++---
 test/mocking_classes.py                       | 53 +++++-----
 test/opengl_rendering.py                      |  1 -
 test/test_collector.py                        | 30 ++----
 test/test_loggers.py                          | 13 ++-
 test/test_specs.py                            |  7 +-
 torchrl/_utils.py                             | 16 ++-
 torchrl/collectors/collectors.py              | 99 ++++++++-----------
 torchrl/collectors/distributed/generic.py     | 29 +++---
 torchrl/collectors/distributed/ray.py         | 38 ++++---
 torchrl/collectors/distributed/rpc.py         | 19 ++--
 torchrl/collectors/distributed/sync.py        | 24 +++--
 torchrl/data/datasets/atari_dqn.py            |  9 +-
 torchrl/data/datasets/d4rl.py                 | 12 +--
 torchrl/data/datasets/minari_data.py          |  6 +-
 torchrl/data/datasets/openml.py               |  3 +-
 torchrl/data/datasets/openx.py                | 20 ++--
 torchrl/data/datasets/roboset.py              |  4 +-
 torchrl/data/datasets/vd4rl.py                | 12 +--
 torchrl/data/map/hash.py                      |  4 +-
 torchrl/data/map/query.py                     |  8 +-
 torchrl/data/map/tdstorage.py                 | 18 ++--
 torchrl/data/map/tree.py                      | 51 ++++------
 torchrl/data/map/utils.py                     |  8 +-
 torchrl/data/replay_buffers/checkpointers.py  | 12 +--
 torchrl/data/replay_buffers/replay_buffers.py | 51 +++++-----
 torchrl/data/replay_buffers/samplers.py       | 69 +++++++------
 torchrl/data/replay_buffers/scheduler.py      |  7 +-
 torchrl/data/replay_buffers/storages.py       | 57 +++++------
 torchrl/data/replay_buffers/utils.py          | 11 +--
 torchrl/data/replay_buffers/writers.py        | 29 +++---
 torchrl/data/rlhf/dataset.py                  | 11 +--
 torchrl/data/rlhf/prompt.py                   | 11 +--
 torchrl/data/rlhf/reward.py                   |  7 +-
 torchrl/data/rlhf/utils.py                    | 13 +--
 torchrl/data/utils.py                         |  6 +-
 torchrl/envs/batched_envs.py                  | 67 ++++++-------
 torchrl/envs/custom/chess.py                  | 33 +++----
 torchrl/envs/custom/tictactoeenv.py           |  5 +-
 torchrl/envs/env_creator.py                   | 10 +-
 torchrl/envs/gym_like.py                      | 33 +++----
 torchrl/envs/libs/dm_control.py               | 36 +++----
 torchrl/envs/libs/envpool.py                  | 41 ++++----
 torchrl/envs/libs/gym.py                      | 24 ++---
 torchrl/envs/libs/isaacgym.py                 | 14 ++-
 torchrl/envs/libs/jumanji.py                  | 26 +++--
 torchrl/envs/libs/meltingpot.py               | 43 ++++----
 torchrl/envs/libs/openspiel.py                | 16 ++-
 torchrl/envs/libs/pettingzoo.py               | 63 ++++++------
 torchrl/envs/libs/robohive.py                 | 10 +-
 torchrl/envs/libs/unity_mlagents.py           | 37 +++----
 torchrl/envs/libs/vmas.py                     | 55 +++++------
 torchrl/envs/model_based/common.py            |  3 +-
 torchrl/envs/model_based/dreamer.py           |  5 +-
 torchrl/envs/transforms/functional.py         |  2 +-
 torchrl/envs/transforms/rb_transforms.py      |  7 +-
 torchrl/envs/transforms/vip.py                |  5 +-
 torchrl/modules/distributions/continuous.py   | 31 +++---
 .../modules/distributions/truncated_normal.py | 12 +--
 torchrl/modules/distributions/utils.py        |  2 +-
 .../modules/models/decision_transformer.py    |  3 +-
 torchrl/modules/models/exploration.py         | 31 +++---
 torchrl/modules/models/models.py              | 60 ++++++-----
 torchrl/modules/models/multiagent.py          | 29 +++---
 torchrl/modules/models/recipes/impala.py      |  2 +-
 torchrl/modules/models/utils.py               |  4 +-
 torchrl/modules/tensordict_module/actors.py   | 86 ++++++++--------
 torchrl/modules/tensordict_module/common.py   | 36 ++++---
 .../modules/tensordict_module/exploration.py  | 45 ++++-----
 .../tensordict_module/probabilistic.py        | 14 ++-
 torchrl/modules/tensordict_module/rnn.py      | 30 +++---
 torchrl/objectives/a2c.py                     | 24 ++---
 torchrl/objectives/common.py                  | 14 ++-
 torchrl/objectives/cql.py                     | 28 +++---
 torchrl/objectives/crossq.py                  | 15 ++-
 torchrl/objectives/ddpg.py                    | 10 +-
 torchrl/objectives/decision_transformer.py    |  7 +-
 torchrl/objectives/deprecated.py              | 20 ++--
 torchrl/objectives/dqn.py                     | 16 ++-
 torchrl/objectives/dreamer.py                 | 15 ++-
 torchrl/objectives/iql.py                     | 26 +++--
 torchrl/objectives/multiagent/qmixer.py       | 20 ++--
 torchrl/objectives/ppo.py                     | 43 ++++----
 torchrl/objectives/redq.py                    | 14 +--
 torchrl/objectives/sac.py                     | 37 ++++---
 torchrl/objectives/td3.py                     | 18 ++--
 torchrl/objectives/td3_bc.py                  | 16 ++-
 torchrl/objectives/utils.py                   | 40 ++++----
 torchrl/objectives/value/advantages.py        | 33 +++----
 torchrl/objectives/value/functional.py        | 27 +++--
 torchrl/record/loggers/csv.py                 | 13 ++-
 torchrl/record/recorder.py                    | 21 ++--
 torchrl/trainers/trainers.py                  | 90 ++++++++---------
 tutorials/sphinx-tutorials/coding_dqn.py      |  2 +-
 131 files changed, 1026 insertions(+), 1418 deletions(-)

diff --git a/.github/unittest/helpers/coverage_run_parallel.py b/.github/unittest/helpers/coverage_run_parallel.py
index 8c6251cf82b..9b97b848f53 100644
--- a/.github/unittest/helpers/coverage_run_parallel.py
+++ b/.github/unittest/helpers/coverage_run_parallel.py
@@ -28,8 +28,8 @@ def write_config(config_path: Path, argv: List[str]) -> None:
         argv: Arguments passed to this script, which need to be converted to config file entries
     """
     assert not config_path.exists(), "Temporary coverage config exists already"
-    cmdline = " ".join(shlex.quote(arg) for arg in argv[1:])
-    with open(str(config_path), "wt", encoding="utf-8") as fh:
+    cmdline = shlex.join(argv[1:])
+    with open(str(config_path), "w", encoding="utf-8") as fh:
         fh.write(
             f"""# .coveragerc to control coverage.py
 [run]
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 532445125aa..f31b7c41f7b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,3 +35,9 @@ repos:
     hooks:
       - id: pydocstyle
         files: ^torchrl/
+
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.9.0
+    hooks:
+      - id: pyupgrade
+        args: [--py38-plus]
diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py
index 6e950caa237..973238763fb 100644
--- a/build_tools/setup_helpers/extension.py
+++ b/build_tools/setup_helpers/extension.py
@@ -8,13 +8,12 @@
 import platform
 import subprocess
 from pathlib import Path
-from subprocess import CalledProcessError, check_output, STDOUT
+from subprocess import CalledProcessError, STDOUT, check_output

 import torch
 from setuptools import Extension
 from setuptools.command.build_ext import build_ext

-
 _THIS_DIR = Path(__file__).parent.resolve()
 _ROOT_DIR = _THIS_DIR.parent.parent.resolve()
 _TORCHRL_DIR = _ROOT_DIR / "torchrl"
@@ -130,7 +129,7 @@ def build_extension(self, ext):
         # using -j in the build_ext call, not supported by pip or PyPA-build.
         if hasattr(self, "parallel") and self.parallel:
             # CMake 3.12+ only.
-            build_args += ["-j{}".format(self.parallel)]
+            build_args += [f"-j{self.parallel}"]

         if not os.path.exists(self.build_temp):
             os.makedirs(self.build_temp)
diff --git a/setup.py b/setup.py
index aebbf09037a..252b67e3187 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,7 @@ def get_version():
     version_txt = os.path.join(cwd, "version.txt")
-    with open(version_txt, "r") as f:
+    with open(version_txt) as f:
         version = f.readline().strip()
     if os.getenv("TORCHRL_BUILD_VERSION"):
         version = os.getenv("TORCHRL_BUILD_VERSION")
@@ -64,8 +64,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
 def write_version_file(version):
     version_path = os.path.join(cwd, "torchrl", "version.py")
     with open(version_path, "w") as f:
-        f.write("__version__ = '{}'\n".format(version))
-        f.write("git_version = {}\n".format(repr(sha)))
+        f.write(f"__version__ = '{version}'\n")
+        f.write(f"git_version = {repr(sha)}\n")


 def _get_pytorch_version(is_nightly, is_local):
@@ -185,7 +185,7 @@ def _main(argv):
     version = get_version()
     write_version_file(version)
     TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION")
-    logging.info("Building wheel {}-{}".format(package_name, version))
+    logging.info(f"Building wheel {package_name}-{version}")
     logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}")

     is_local = TORCHRL_BUILD_VERSION is None
diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 3279d6e0a2b..4d12a75ea0f 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -13,7 +13,7 @@


 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     from copy import deepcopy
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 41e05dc1326..d07ee6621af 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -13,7 +13,7 @@


 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     from copy import deepcopy
diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index 2e1a20ad7a2..fc388399878 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -15,16 +15,13 @@
 import hydra
 import numpy as np
-
 import torch
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -39,7 +36,7 @@


 @hydra.main(config_path="", config_name="offline_config", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     # Create logger
     exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
     logger = None
diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py
index e992bdb5939..5d25a34ba10 100644
--- a/sota-implementations/cql/cql_online.py
+++ b/sota-implementations/cql/cql_online.py
@@ -21,12 +21,10 @@
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -42,7 +40,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="online_config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     # Create logger
     exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
     logger = None
diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py
index d45ce3745fe..2f7441ee4eb 100644
--- a/sota-implementations/cql/discrete_cql_online.py
+++ b/sota-implementations/cql/discrete_cql_online.py
@@ -16,16 +16,12 @@
 import hydra
 import numpy as np
-
 import torch
 import torch.cuda
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     log_metrics,
@@ -41,7 +37,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.optim.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py
index d84613e6876..619f2395fb1 100644
--- a/sota-implementations/crossq/crossq.py
+++ b/sota-implementations/crossq/crossq.py
@@ -15,19 +15,15 @@
 import warnings

 import hydra
-
 import numpy as np
-
 import torch
 import torch.cuda
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     log_metrics,
@@ -43,7 +39,7 @@


 @hydra.main(version_base="1.1", config_path=".", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.network.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py
index bcb7ee6ef54..5b6d308aba2 100644
--- a/sota-implementations/ddpg/ddpg.py
+++ b/sota-implementations/ddpg/ddpg.py
@@ -15,16 +15,13 @@
 import warnings

 import hydra
-
 import numpy as np
 import torch
 import torch.cuda
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
@@ -41,7 +38,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.optim.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py
index 9e8446ed82f..f565aafeafc 100644
--- a/sota-implementations/decision_transformer/dt.py
+++ b/sota-implementations/decision_transformer/dt.py
@@ -19,11 +19,9 @@
 from tensordict.nn import CudaGraphModule
 from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.libs.gym import set_gym_backend
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 from torchrl.record import VideoRecorder
-
 from utils import (
     dump_video,
     log_metrics,
@@ -37,7 +35,7 @@


 @hydra.main(config_path="", config_name="dt_config", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     model_device = cfg.optim.device
diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py
index 1404cb7ebc0..baab8bbb9a6 100644
--- a/sota-implementations/decision_transformer/online_dt.py
+++ b/sota-implementations/decision_transformer/online_dt.py
@@ -20,7 +20,6 @@
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 from torchrl.record import VideoRecorder
-
 from utils import (
     dump_video,
     log_metrics,
@@ -34,7 +33,7 @@


 @hydra.main(config_path="", config_name="odt_config", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     model_device = cfg.optim.device
diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py
index 9ff50902887..1c97163b95a 100644
--- a/sota-implementations/discrete_sac/discrete_sac.py
+++ b/sota-implementations/discrete_sac/discrete_sac.py
@@ -38,7 +38,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.network.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index 786e5d2ebb0..c2bffd91869 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -18,7 +18,6 @@
 import tqdm
 from tensordict.nn import CudaGraphModule, TensorDictSequential
 from torchrl._utils import timeit
-
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.envs import ExplorationType, set_exploration_type
@@ -32,7 +31,7 @@


 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     device = cfg.device
     if device in ("", None):
diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py
index 4fde452fba9..87be7fd603a 100644
--- a/sota-implementations/dqn/dqn_cartpole.py
+++ b/sota-implementations/dqn/dqn_cartpole.py
@@ -11,7 +11,6 @@
 import torch.nn
 import torch.optim
 import tqdm
-
 from tensordict.nn import CudaGraphModule, TensorDictSequential
 from torchrl._utils import timeit
 from torchrl.collectors import SyncDataCollector
@@ -27,7 +26,7 @@


 @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     device = cfg.device
     if device in ("", None):
diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py
index a197796e978..4534f93a4f5 100644
--- a/sota-implementations/dreamer/dreamer.py
+++ b/sota-implementations/dreamer/dreamer.py
@@ -20,14 +20,12 @@
     make_environments,
     make_replay_buffer,
 )
-
 # mixed precision training
 from torch.amp import GradScaler
 from torch.nn.utils import clip_grad_norm_
 from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import RSSMRollout
-
 from torchrl.objectives.dreamer import (
     DreamerActorLoss,
     DreamerModelLoss,
@@ -37,7 +35,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     # cfg = correct_for_frame_skip(cfg)

     device = _default_device(cfg.networks.device)
diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py
index bdb8843aaf6..c7fa393a2bd 100644
--- a/sota-implementations/gail/gail.py
+++ b/sota-implementations/gail/gail.py
@@ -17,16 +17,13 @@
 import numpy as np
 import torch
 import tqdm
-
 from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
 from ppo_utils import eval_model, make_env, make_ppo_models
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import compile_with_warmup, timeit
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
@@ -34,12 +31,11 @@
 from torchrl.record import VideoRecorder
 from torchrl.record.loggers import generate_exp_name, get_logger

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     device = cfg.gail.device
diff --git a/sota-implementations/gail/gail_utils.py b/sota-implementations/gail/gail_utils.py
index ce09292cc47..328a0864f31 100644
--- a/sota-implementations/gail/gail_utils.py
+++ b/sota-implementations/gail/gail_utils.py
@@ -6,11 +6,9 @@

 import torch.nn as nn
 import torch.optim
-
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay
 from torchrl.data.replay_buffers import SamplerWithoutReplacement
 from torchrl.envs import DoubleToFloat
-
 from torchrl.modules import SafeModule


@@ -45,7 +43,7 @@ def make_gail_discriminator(cfg, train_env, device="cpu"):
     # Define Discriminator Network
     class Discriminator(nn.Module):
         def __init__(self, state_dim, action_dim):
-            super(Discriminator, self).__init__()
+            super().__init__()
             self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
             self.fc2 = nn.Linear(hidden_dim, hidden_dim)
             self.fc3 = nn.Linear(hidden_dim, 1)
diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py
index dcf908c2cd2..5364c82c7b2 100644
--- a/sota-implementations/impala/impala_multi_node_ray.py
+++ b/sota-implementations/impala/impala_multi_node_ray.py
@@ -14,7 +14,7 @@


 @hydra.main(config_path="", config_name="config_multi_node_ray", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import time
diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py
index 4d90e9053bd..527821820ca 100644
--- a/sota-implementations/impala/impala_multi_node_submitit.py
+++ b/sota-implementations/impala/impala_multi_node_submitit.py
@@ -16,7 +16,7 @@
 @hydra.main(
     config_path="", config_name="config_multi_node_submitit", version_base="1.1"
 )
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import time
diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py
index cda63ac0919..b7af2adbc38 100644
--- a/sota-implementations/impala/impala_single_node.py
+++ b/sota-implementations/impala/impala_single_node.py
@@ -14,7 +14,7 @@


 @hydra.main(config_path="", config_name="config_single_node", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import time
diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py
index aa4cea04024..43a8dcafa6e 100644
--- a/sota-implementations/iql/discrete_iql.py
+++ b/sota-implementations/iql/discrete_iql.py
@@ -21,14 +21,11 @@
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -40,12 +37,11 @@
     make_replay_buffer,
 )

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(config_path="", config_name="discrete_iql")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     # Create logger
diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py
index eaf791438cc..6585534ff68 100644
--- a/sota-implementations/iql/iql_offline.py
+++ b/sota-implementations/iql/iql_offline.py
@@ -18,14 +18,11 @@
 import torch
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -36,12 +33,11 @@
     make_offline_replay_buffer,
 )

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(config_path="", config_name="offline_config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     # Create logger
diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py
index 5b90f00c467..eaa37f29176 100644
--- a/sota-implementations/iql/iql_online.py
+++ b/sota-implementations/iql/iql_online.py
@@ -20,14 +20,11 @@
 import torch
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -39,12 +36,11 @@
     make_replay_buffer,
 )

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(config_path="", config_name="online_config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     # Create logger
diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py
index 2692c1c24b5..56673ac9590 100644
--- a/sota-implementations/multiagent/iql.py
+++ b/sota-implementations/multiagent/iql.py
@@ -8,7 +8,6 @@

 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torch import nn
 from torchrl._utils import logger as torchrl_logger
@@ -31,7 +30,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="iql")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py
index f04ccb19071..eb6700d766f 100644
--- a/sota-implementations/multiagent/maddpg_iddpg.py
+++ b/sota-implementations/multiagent/maddpg_iddpg.py
@@ -8,7 +8,6 @@

 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torch import nn
 from torchrl._utils import logger as torchrl_logger
@@ -36,7 +35,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="maddpg_iddpg")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py
index 924ea12272a..0d80896fc9b 100644
--- a/sota-implementations/multiagent/mappo_ippo.py
+++ b/sota-implementations/multiagent/mappo_ippo.py
@@ -8,7 +8,6 @@

 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 from torch import nn
@@ -32,7 +31,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="mappo_ippo")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py
index a832a29e6dd..4fed4fea5f5 100644
--- a/sota-implementations/multiagent/qmix_vdn.py
+++ b/sota-implementations/multiagent/qmix_vdn.py
@@ -8,7 +8,6 @@

 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torch import nn
 from torchrl._utils import logger as torchrl_logger
@@ -31,7 +30,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="qmix_vdn")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py
index 31106bdd2a0..cc30011f4a0 100644
--- a/sota-implementations/multiagent/sac.py
+++ b/sota-implementations/multiagent/sac.py
@@ -8,7 +8,6 @@

 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 from torch import nn
@@ -33,7 +32,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="sac")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 8ecb675535b..25b6f63e893 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -12,12 +12,11 @@
 import warnings

 import hydra
-
 from torchrl._utils import compile_with_warmup


 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import torch.optim
     import tqdm
diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py
index 27ae7e57848..a17d0b90339 100644
--- a/sota-implementations/ppo/ppo_mujoco.py
+++ b/sota-implementations/ppo/ppo_mujoco.py
@@ -12,12 +12,11 @@
 import warnings

 import hydra
-
 from torchrl._utils import compile_with_warmup


 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import torch.optim
     import tqdm
diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py
index 3dec888145c..58072701663 100644
--- a/sota-implementations/redq/redq.py
+++ b/sota-implementations/redq/redq.py
@@ -42,7 +42,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     cfg = correct_for_frame_skip(cfg)
diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py
index b67b02c42f9..2fc55cd853f 100644
--- a/sota-implementations/redq/utils.py
+++ b/sota-implementations/redq/utils.py
@@ -5,7 +5,7 @@
 from __future__ import annotations

 from copy import copy
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Callable, Sequence

 import torch
 from omegaconf import OmegaConf
@@ -17,10 +17,8 @@
 )
 from torch import distributions as d, nn, optim
 from torch.optim.lr_scheduler import CosineAnnealingLR
-
-from torchrl._utils import logger as torchrl_logger, VERBOSE
+from torchrl._utils import VERBOSE, logger as torchrl_logger
 from torchrl.collectors.collectors import DataCollectorBase
-
 from torchrl.data import (
     LazyMemmapStorage,
     MultiStep,
@@ -37,12 +35,10 @@
     Compose,
     DMControlEnv,
     DoubleToFloat,
-    env_creator,
     EnvBase,
     EnvCreator,
     FlattenObservation,
     GrayScale,
-    gSDENoise,
     GymEnv,
     InitTracker,
     NoopResetEnv,
@@ -54,6 +50,8 @@
     ToTensorImage,
     TransformedEnv,
     VecNorm,
+    env_creator,
+    gSDENoise,
 )
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
@@ -105,7 +103,7 @@
 }


-def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig":  # noqa: F821
+def correct_for_frame_skip(cfg: DictConfig) -> DictConfig:  # noqa: F821
     """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.

     This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targeting a total number of frames
@@ -172,7 +170,7 @@ def make_trainer(
     policy_exploration: TensorDictModuleWrapper | TensorDictModule | None,
     replay_buffer: ReplayBuffer | None,
     logger: Logger | None,
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
 ) -> Trainer:
     """Creates a Trainer instance given its constituents.

@@ -377,7 +375,7 @@ def make_trainer(

 def make_redq_model(
     proof_environment: EnvBase,
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
     device: DEVICE_TYPING = "cpu",
     in_keys: Sequence[str] | None = None,
     actor_net_kwargs=None,
@@ -555,7 +553,7 @@ def make_redq_model(

 def transformed_env_constructor(
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
     video_tag: str = "",
     logger: Logger | None = None,
     stats: dict | None = None,
@@ -568,7 +566,7 @@ def transformed_env_constructor(
     state_dim_gsde: int | None = None,
     batch_dims: int | None = 0,
     obs_norm_state_dict: dict | None = None,
-) -> Union[Callable, EnvCreator]:
+) -> Callable | EnvCreator:
     """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.

     Args:
@@ -688,7 +686,7 @@ def get_norm_state_dict(env):
 def initialize_observation_norm_transforms(
     proof_environment: EnvBase,
     num_iter: int = 1000,
-    key: Union[str, Tuple[str, ...]] = None,
+    key: str | tuple[str, ...] = None,
 ):
     """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.

@@ -729,8 +727,8 @@ def initialize_observation_norm_transforms(

 def parallel_env_constructor(
-    cfg: "DictConfig", **kwargs  # noqa: F821
-) -> Union[ParallelEnv, EnvCreator]:
+    cfg: DictConfig, **kwargs  # noqa: F821
+) -> ParallelEnv | EnvCreator:
     """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.

     Args:
@@ -916,9 +914,7 @@ def make_env_transforms(
     return env


-def make_redq_loss(
-    model, cfg
-) -> Tuple[REDQLoss_deprecated, Optional[TargetNetUpdater]]:
+def make_redq_loss(model, cfg) -> tuple[REDQLoss_deprecated, TargetNetUpdater | None]:
     """Builds the REDQ loss module."""
     loss_kwargs = {}
     loss_kwargs.update({"loss_function": cfg.loss.loss_function})
@@ -950,7 +946,7 @@ def make_redq_loss(

 def make_target_updater(
-    cfg: "DictConfig", loss_module: LossModule  # noqa: F821
+    cfg: DictConfig, loss_module: LossModule  # noqa: F821
 ) -> TargetNetUpdater | None:
     """Builds a target network weight update object."""
     if cfg.loss.type == "double":
@@ -976,8 +972,8 @@ def make_target_updater(
 def make_collector_offpolicy(
     make_env: Callable[[], EnvBase],
     actor_model_explore: TensorDictModuleWrapper | ProbabilisticTensorDictSequential,
-    cfg: "DictConfig",  # noqa: F821
-    make_env_kwargs: Dict | None = None,
+    cfg: DictConfig,  # noqa: F821
+    make_env_kwargs: dict | None = None,
 ) -> DataCollectorBase:
     """Returns a data collector for off-policy sota-implementations.

@@ -1037,7 +1033,7 @@ def make_collector_offpolicy(

 def make_replay_buffer(
-    device: DEVICE_TYPING, cfg: "DictConfig"  # noqa: F821
+    device: DEVICE_TYPING, cfg: DictConfig  # noqa: F821
 ) -> ReplayBuffer:  # noqa: F821
     """Builds a replay buffer using the config built from ReplayArgsConfig."""
     device = torch.device(device)
diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py
index e159824f9cd..7fd6284037e 100644
--- a/sota-implementations/sac/sac.py
+++ b/sota-implementations/sac/sac.py
@@ -15,18 +15,15 @@
 import warnings

 import hydra
-
 import numpy as np
 import torch
 import torch.cuda
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import compile_with_warmup, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     dump_video,
@@ -43,7 +40,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.network.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py
index 3a741735a1c..f7b10e8cdf9 100644
--- a/sota-implementations/td3/td3.py
+++ b/sota-implementations/td3/td3.py
@@ -20,11 +20,8 @@
 import torch.cuda
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import compile_with_warmup, timeit
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     dump_video,
@@ -37,12 +34,11 @@
     make_td3_agent,
 )

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.network.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py
index ac65f2875cf..6c628904908 100644
--- a/sota-implementations/td3_bc/td3_bc.py
+++ b/sota-implementations/td3_bc/td3_bc.py
@@ -19,13 +19,10 @@
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import compile_with_warmup, timeit
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -38,7 +35,7 @@


 @hydra.main(config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.library).set()

     # Create logger
diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 5c4b9930089..89762567529 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -7,7 +7,6 @@
 import contextlib
 import logging
 import os
-
 import os.path
 import sys
 import time
@@ -15,28 +14,18 @@
 import warnings
 from functools import wraps

-# Get relative file path
-# this returns relative path from current file.
-
 import pytest
 import torch
 import torch.cuda
-
-from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase
+from tensordict import NestedKey, TensorDict, TensorDictBase, tensorclass
 from tensordict.nn import TensorDictModuleBase
 from torch import nn, vmap
-from torchrl._utils import (
-    implement_for,
-    logger as torchrl_logger,
-    RL_WARNINGS,
-    seed_generator,
-)
+from torchrl._utils import (RL_WARNINGS, implement_for, logger as torchrl_logger, seed_generator)
 from torchrl.data.utils import CloudpickleWrapper
-
 from torchrl.envs import MultiThreadedEnv, ObservationNorm
 from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
 from torchrl.envs.libs.envpool import _has_envpool
-from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv
+from torchrl.envs.libs.gym import GymEnv, _has_gym, gym_backend
 from torchrl.envs.transforms import (
     Compose,
     RewardClipping,
@@ -46,6 +35,9 @@
 from torchrl.modules import MLP
 from torchrl.objectives.value.advantages import _vmap_func

+# Get relative file path
+# this returns relative path from current file.
+
 # Specified for test_utils.py
 __version__ = "0.3"

@@ -671,7 +663,7 @@ def _lstm(

         if hidden1_in is None and hidden0_in is None:
             shape = (batch, steps) if not squeeze1 else (batch,)
-            hidden0_in, hidden1_in = [
+            hidden0_in, hidden1_in = (
                 torch.zeros(
                     *shape,
                     self.lstm.num_layers,
@@ -680,7 +672,7 @@ def _lstm(
                     dtype=input.dtype,
                 )
                 for _ in range(2)
-            ]
+            )
         elif hidden1_in is None or hidden0_in is None:
             raise RuntimeError(
                 f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}"
diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index d407a2ac241..2b77b45df40 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -6,16 +6,13 @@
 import random
 import string
-from typing import Dict, List, Optional

 import numpy as np
-
 import torch
 import torch.nn as nn
-from tensordict import tensorclass, TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase, tensorclass
 from tensordict.nn import TensorDictModuleBase
-from tensordict.utils import expand_right, NestedKey
-
+from tensordict.utils import NestedKey, expand_right
 from torchrl.data import (
     Binary,
     Bounded,
@@ -31,11 +28,7 @@
 from torchrl.envs import Transform
 from torchrl.envs.common import EnvBase
 from torchrl.envs.model_based.common import ModelBasedEnvBase
-from torchrl.envs.utils import (
-    _terminated_or_truncated,
-    check_marl_grouping,
-    MarlGroupMapType,
-)
+from torchrl.envs.utils import (MarlGroupMapType, _terminated_or_truncated, check_marl_grouping)

 spec_dict = {
     "bounded": Bounded,
@@ -133,7 +126,7 @@ def __init__(
     def maxstep(self):
         return 100

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         self.seed = seed
         self.counter = seed % 17  # make counter a small number
@@ -218,10 +211,10 @@ def __new__(
         return super().__new__(*args, **kwargs)

     def __init__(self, device="cpu"):
-        super(MockSerialEnv, self).__init__(device=device)
+        super().__init__(device=device)
         self.is_closed = False

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         assert seed >= 1
         self.seed = seed
         self.counter = seed % 17  # make counter a small number
@@ -259,7 +252,7 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
             device=self.device,
         )

-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         return self.step(tensordict)

@@ -338,12 +331,12 @@ def __new__(
         return super().__new__(cls, *args, **kwargs)

     def __init__(self, device="cpu", batch_size=None):
-        super(MockBatchedLockedEnv, self).__init__(device=device, batch_size=batch_size)
+        super().__init__(device=device, batch_size=batch_size)
         self.counter = 0

     rand_step = MockSerialEnv.rand_step

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         assert seed >= 1
         self.seed = seed
         self.counter = seed % 17  # make counter a small number
@@ -422,9 +415,7 @@ class MockBatchedUnLockedEnv(MockBatchedLockedEnv):
     """

     def __init__(self, device="cpu", batch_size=None):
-        super(MockBatchedUnLockedEnv, self).__init__(
-            batch_size=batch_size, device=device
-        )
+        super().__init__(batch_size=batch_size, device=device)

     @classmethod
     def __new__(cls, *args, **kwargs):
@@ -510,7 +501,7 @@ def _step(
             device=tensordict.device,
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         ...

@@ -1113,7 +1104,7 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
             torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1206,7 +1197,7 @@ def __init__(
         self,
         n_agents: int,
         group_map: MarlGroupMapType
-        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
         max_steps: int = 5,
         start_val: int = 0,
         **kwargs,
@@ -1287,7 +1278,7 @@ def __init__(
             torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1600,7 +1591,7 @@ def __init__(
         elif start_val.numel() <= 1:
             self.start_val = start_val.expand_as(self.count)

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1816,7 +1807,7 @@ def _step(

         return td

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

@@ -2047,7 +2038,7 @@ def _step(
         assert td.batch_size == self.batch_size
         return td

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

@@ -2084,7 +2075,7 @@ def _step(
         data.update(self._saved_full_reward_spec.zero())
         return data

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         return seed

@@ -2210,7 +2201,7 @@ def _step(
         reward = self.full_reward_spec.zero()
         return observation.update(done).update(reward)

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         self.manual_seed = seed
         return seed
@@ -2280,7 +2271,7 @@ def _step(
             ),
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         ...

@@ -2328,7 +2319,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
     def get_random_string(self):
         return get_random_string(self.min_size, self.max_size)

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         random.seed(seed)
         torch.manual_seed(0)
         return seed
@@ -2356,7 +2347,7 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
             .update(self.full_reward_spec.zero())
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         ...
diff --git a/test/opengl_rendering.py b/test/opengl_rendering.py
index 0e2f86294c1..1ec4c248841 100644
--- a/test/opengl_rendering.py
+++ b/test/opengl_rendering.py
@@ -23,7 +23,6 @@
 # OpenGL context is available here.
 """

-from __future__ import print_function

 # pylint: disable=unused-import,g-import-not-at-top,g-statement-before-imports
diff --git a/test/test_collector.py b/test/test_collector.py
index f3935d1086b..523f851b37d 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -10,29 +10,20 @@
 import gc
 import os
 import subprocess
-
 import sys
-from typing import Optional
 from unittest.mock import patch

 import numpy as np
 import pytest
 import torch
 from packaging import version
-from tensordict import (
-    assert_allclose_td,
-    LazyStackedTensorDict,
-    NonTensorData,
-    TensorDict,
-    TensorDictBase,
-)
+from tensordict import (LazyStackedTensorDict, NonTensorData, TensorDict, TensorDictBase, assert_allclose_td)
 from tensordict.nn import (
     CudaGraphModule,
     TensorDictModule,
     TensorDictModuleBase,
     TensorDictSequential,
 )
-
 from torch import nn
 from torchrl._utils import (
     _make_ordinal_device,
@@ -41,12 +32,8 @@
     prod,
     seed_generator,
 )
-from torchrl.collectors import aSyncDataCollector, SyncDataCollector
-from torchrl.collectors.collectors import (
-    _Interruptor,
-    MultiaSyncDataCollector,
-    MultiSyncDataCollector,
-)
+from torchrl.collectors import SyncDataCollector, aSyncDataCollector
+from torchrl.collectors.collectors import (MultiSyncDataCollector, MultiaSyncDataCollector, _Interruptor)
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     Composite,
@@ -67,14 +54,9 @@
     StepCounter,
     Transform,
 )
-from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, set_gym_backend
+from torchrl.envs.libs.gym import GymEnv, _has_gym, gym_backend, set_gym_backend
 from torchrl.envs.transforms import TransformedEnv, VecNorm
-from torchrl.envs.utils import (
-    _aggregate_end_of_traj,
-    check_env_specs,
-    PARTIAL_MISSING_ERR,
-    RandomPolicy,
-)
+from torchrl.envs.utils import (PARTIAL_MISSING_ERR, RandomPolicy, _aggregate_end_of_traj, check_env_specs)
 from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule

 if os.getenv("PYTORCH_TEST_FBCODE"):
@@ -1820,7 +1802,7 @@ def _step(
     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         return self.full_done_specs.zeros().update(self.observation_spec.zeros())

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         return seed

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
diff --git a/test/test_loggers.py b/test/test_loggers.py
index 6b659fe7245..cf438feb9f1 100644
--- a/test/test_loggers.py
+++ b/test/test_loggers.py
@@ -14,12 +14,11 @@
 import pytest
 import torch
 from tensordict import MemoryMappedTensor
-
-from torchrl.envs import check_env_specs, GymEnv, ParallelEnv
+from torchrl.envs import GymEnv, ParallelEnv, check_env_specs
 from torchrl.record.loggers.csv import CSVLogger
-from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
-from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
-from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
+from torchrl.record.loggers.mlflow import MLFlowLogger, _has_mlflow, _has_tv
+from torchrl.record.loggers.tensorboard import TensorboardLogger, _has_tb
+from torchrl.record.loggers.wandb import WandbLogger, _has_wandb
 from torchrl.record.recorder import PixelRenderTransform, VideoRecorder

 if _has_tv:
@@ -154,7 +153,7 @@ def test_log_scalar(self, steps, tmpdir):
                 step=steps[i] if steps else None,
             )

-        with open(os.path.join(tmpdir, exp_name, "scalars", "foo.csv"), "r") as file:
+        with open(os.path.join(tmpdir, exp_name, "scalars", "foo.csv")) as file:
             for i, row in enumerate(file.readlines()):
                 step = steps[i] if steps else i
                 assert row == f"{step},{values[i].item()}\n"
@@ -239,7 +238,7 @@ def test_log_config(self, tmpdir, config):
         logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name)
         logger.log_hparams(cfg=config)

-        with open(os.path.join(tmpdir, exp_name, "texts", "hparams0.txt"), "r") as file:
+        with open(os.path.join(tmpdir, exp_name, "texts", "hparams0.txt")) as file:
             txt = "\n".join([f"{k}: {val}" for k, val in sorted(config.items())])
             text = "".join(file.readlines())
             assert text == txt
diff --git a/test/test_specs.py b/test/test_specs.py
index f523dde54f7..142aa128c6e 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -21,9 +21,7 @@
 )
 from tensordict.utils import _unravel_key_to_tuple, set_capture_non_tensor_stack
 from torchrl._utils import _make_ordinal_device
-
 from torchrl.data.tensor_specs import (
-    _keys_to_empty_composite_spec,
     Binary,
     BinaryDiscreteTensorSpec,
     Bounded,
@@ -49,6 +47,7 @@
     UnboundedContinuousTensorSpec,
     UnboundedDiscrete,
     UnboundedDiscreteTensorSpec,
+    _keys_to_empty_composite_spec,
 )
 from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec

@@ -3573,8 +3572,8 @@ def test_valid_indexing(spec_class):
     assert spec_3d[1:, range(3)].shape == torch.Size([4, 3, 4])
     assert spec_3d[[[[[0, 1]]]], [[0]]].shape == torch.Size([1, 1, 1, 2, 4])
     assert spec_3d[0, [[[[0, 1]]]]].shape == torch.Size([1, 1, 1, 2, 4])
-    assert spec_3d[0, ((((0, 1))))].shape == torch.Size([2, 4])
-    assert spec_3d[((((0, 1)))), [0, 2]].shape == torch.Size([2, 4])
+    assert spec_3d[0, ((0, 1))].shape == torch.Size([2, 4])
+    assert spec_3d[((0, 1)), [0, 2]].shape == torch.Size([2, 4])
     assert spec_4d[2:, [[[0, 1]]], :3].shape == torch.Size([3, 1, 1, 2, 3, 6])
     assert spec_5d[2:, [[[0, 1]]], [[0, 1]], :3].shape == torch.Size([3, 1, 1, 2, 3, 7])
     assert spec_5d[2:, [[[0, 1]]], 0, :3].shape == torch.Size([3, 1, 1, 2, 3, 7])
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 1d541750ec2..f614b7f04a0 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -5,12 +5,9 @@
 from __future__ import annotations

 import collections
-
 import functools
 import inspect
-
 import logging
-
 import math
 import os
 import pickle
@@ -24,15 +21,14 @@
 from distutils.util import strtobool
 from functools import wraps
 from importlib import import_module
-from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union
+from typing import Any, Callable, Tuple, TypeVar, cast

 import numpy as np
 import torch
 from packaging.version import parse
 from tensordict import unravel_key
-
 from tensordict.utils import NestedKey
-from torch import multiprocessing as mp, Tensor
+from torch import Tensor, multiprocessing as mp

 try:
     from torch.compiler import is_compiling
@@ -345,7 +341,7 @@ class implement_for:

     def __init__(
         self,
-        module_name: Union[str, Callable],
+        module_name: str | Callable,
         from_version: str = None,
         to_version: str = None,
         *,
@@ -419,7 +415,7 @@ def module_set(self):
             setattr(cls, self.fn.__name__, self.fn)

     @classmethod
-    def import_module(cls, module_name: Union[Callable, str]) -> str:
+    def import_module(cls, module_name: Callable | str) -> str:
         """Imports module and returns its version."""
         if not callable(module_name):
             module = cls._cache_modules.get(module_name, None)
@@ -515,7 +511,7 @@ def unsupported(*args, **kwargs):
         return unsupported

     @classmethod
-    def reset(cls, setters_dict: Dict[str, implement_for] = None):
+    def reset(cls, setters_dict: dict[str, implement_for] = None):
         """Resets the setters in setter_dict.

         ``setter_dict`` is a copy of implementations. We just need to iterate through its
@@ -880,7 +876,7 @@ def set_mode(self, type: Any | None) -> None:

 def _standardize(
     input: Tensor,
-    exclude_dims: Tuple[int] = (),
+    exclude_dims: tuple[int] = (),
     mean: Tensor | None = None,
     std: Tensor | None = None,
     eps: float | None = None,
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 2184ae9e19c..d18e2d9a25b 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -7,9 +7,7 @@
 import _pickle
 import abc
 import collections
-
 import contextlib
-
 import functools
 import os
 import queue
@@ -17,16 +15,15 @@
 import time
 import typing
 import warnings
-from collections import defaultdict, OrderedDict
+from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from multiprocessing import connection, queues
 from multiprocessing.managers import SyncManager
 from queue import Empty
 from textwrap import indent
-from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Iterator, Sequence

 import numpy as np
-
 import torch
 import torch.nn as nn
 from tensordict import (
@@ -41,32 +38,31 @@
 from torch import multiprocessing as mp
 from torch.nn import Parameter
 from torch.utils.data import IterableDataset
-
 from torchrl._utils import (
+    RL_WARNINGS,
+    VERBOSE,
+    _ProcessNoWarn,
     _check_for_faulty_process,
     _ends_with,
     _make_ordinal_device,
-    _ProcessNoWarn,
     _replace_last,
     accept_remote_rref_udf_invocation,
     compile_with_warmup,
     logger as torchrl_logger,
     prod,
-    RL_WARNINGS,
-    VERBOSE,
 )
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import ReplayBuffer
 from torchrl.data.tensor_specs import TensorSpec
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
-from torchrl.envs.common import _do_nothing, EnvBase
+from torchrl.envs.common import EnvBase, _do_nothing
 from torchrl.envs.env_creator import EnvCreator
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 from torchrl.envs.utils import (
-    _aggregate_end_of_traj,
-    _make_compatible_policy,
     ExplorationType,
     RandomPolicy,
+    _aggregate_end_of_traj,
+    _make_compatible_policy,
     set_exploration_type,
 )
@@ -162,7 +158,7 @@ def _get_policy_and_device(
         policy_device: Any = NO_DEFAULT,
         env_maker: Any | None = None,
         env_maker_kwargs: dict | None = None,
-    ) -> Tuple[TensorDictModule, Union[None, Callable[[], dict]]]:
+    ) -> tuple[TensorDictModule, None | Callable[[], dict]]:
        """Util method to get a policy and its device given the collector __init__ inputs.

         We want to copy the policy and then move the data there, not call policy.to(device).
@@ -245,7 +241,7 @@ def map_weight(
         return policy, get_original_weights

     def update_policy_weights_(
-        self, policy_weights: Optional[TensorDictBase] = None
+        self, policy_weights: TensorDictBase | None = None
     ) -> None:
         """Updates the policy weights if the policy of the data collector and the trained policy live on different devices.
@@ -513,15 +509,11 @@ class SyncDataCollector(DataCollectorBase):

     def __init__(
         self,
-        create_env_fn: Union[
-            EnvBase, "EnvCreator", Sequence[Callable[[], EnvBase]]  # noqa: F821
-        ],  # noqa: F821
-        policy: Optional[
-            Union[
-                TensorDictModule,
-                Callable[[TensorDictBase], TensorDictBase],
-            ]
-        ] = None,
+        create_env_fn: (
+            EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]]  # noqa: F821
+        ),  # noqa: F821
+        policy: None
+        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
         *,
         frames_per_batch: int,
         total_frames: int = -1,
@@ -543,8 +535,8 @@ def __init__(
         use_buffers: bool | None = None,
         replay_buffer: ReplayBuffer | None = None,
         trust_policy: bool = None,
-        compile_policy: bool | Dict[str, Any] | None = None,
-        cudagraph_policy: bool | Dict[str, Any] | None = None,
+        compile_policy: bool | dict[str, Any] | None = None,
+        cudagraph_policy: bool | dict[str, Any] | None = None,
         no_cuda_sync: bool = False,
         **kwargs,
     ):
@@ -990,7 +982,7 @@ def next(self):

     # for RPC
     def update_policy_weights_(
-        self, policy_weights: Optional[TensorDictBase] = None
+        self, policy_weights: TensorDictBase | None = None
     ) -> None:
         super().update_policy_weights_(policy_weights)
@@ -1617,25 +1609,21 @@ class _MultiDataCollector(DataCollectorBase):
     def __init__(
         self,
         create_env_fn: Sequence[Callable[[], EnvBase]],
-        policy: Optional[
-            Union[
-                TensorDictModule,
-                Callable[[TensorDictBase], TensorDictBase],
-            ]
-        ] = None,
+        policy: None
+        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
         *,
         frames_per_batch: int,
-        total_frames: Optional[int] = -1,
+        total_frames: int | None = -1,
         device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
-        create_env_kwargs: Optional[Sequence[dict]] = None,
+        create_env_kwargs: Sequence[dict] | None = None,
         max_frames_per_traj: int | None = None,
         init_random_frames: int | None = None,
         reset_at_each_iter: bool = False,
-        postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        split_trajs: Optional[bool] = None,
+        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        split_trajs: bool | None = None,
         exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
         reset_when_done: bool = True,
         update_at_each_batch: bool = False,
@@ -1648,8 +1636,8 @@ def __init__(
         replay_buffer: ReplayBuffer | None = None,
         replay_buffer_chunk: bool = True,
         trust_policy: bool = None,
-        compile_policy: bool | Dict[str, Any] | None = None,
-        cudagraph_policy: bool | Dict[str, Any] | None = None,
+        compile_policy: bool | dict[str, Any] | None = None,
+        cudagraph_policy: bool | dict[str, Any] | None = None,
         no_cuda_sync: bool = False,
     ):
         self.closed = True
@@ -2086,7 +2074,7 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
         self.reset()
         return seed

-    def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None:
+    def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
         """Resets the environments to a new initial state.

         Args:
@@ -2282,7 +2270,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:

     # for RPC
     def update_policy_weights_(
-        self, policy_weights: Optional[TensorDictBase] = None
+        self, policy_weights: TensorDictBase | None = None
     ) -> None:
         super().update_policy_weights_(policy_weights)
@@ -2646,7 +2634,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:

     # for RPC
     def update_policy_weights_(
-        self, policy_weights: Optional[TensorDictBase] = None
+        self, policy_weights: TensorDictBase | None = None
     ) -> None:
         super().update_policy_weights_(policy_weights)

@@ -2654,7 +2642,7 @@ def update_policy_weights_(
     def frames_per_batch_worker(self):
         return self.requested_frames_per_batch

-    def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
+    def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
         new_data, j = self.queue_out.get(timeout=timeout)
         use_buffers = self._use_buffers
         if self.replay_buffer is not None:
@@ -2745,7 +2733,7 @@ def _shutdown_main(self) -> None:
             del self.out_tensordicts
         return super()._shutdown_main()

-    def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None:
+    def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
         super().reset(reset_idx)
         if self.queue_out.full():
             time.sleep(_TIMEOUT)  # wait until queue is empty
@@ -2900,25 +2888,20 @@ class aSyncDataCollector(MultiaSyncDataCollector):
     def __init__(
         self,
         create_env_fn: Callable[[], EnvBase],
-        policy: Optional[
-            Union[
-                TensorDictModule,
-                Callable[[TensorDictBase], TensorDictBase],
-            ]
-        ],
+        policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]),
         *,
         frames_per_batch: int,
-        total_frames: Optional[int] = -1,
+        total_frames: int | None = -1,
         device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
-        create_env_kwargs: Optional[Sequence[dict]] = None,
+        create_env_kwargs: Sequence[dict] | None = None,
         max_frames_per_traj: int | None = None,
         init_random_frames: int | None = None,
         reset_at_each_iter: bool = False,
-        postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        split_trajs: Optional[bool] = None,
+        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        split_trajs: bool | None = None,
         exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
         reset_when_done: bool = True,
         update_at_each_batch: bool = False,
@@ -2977,15 +2960,15 @@ def _main_async_collector(
     pipe_parent: connection.Connection,
     pipe_child: connection.Connection,
     queue_out: queues.Queue,
-    create_env_fn: Union[EnvBase, "EnvCreator", Callable[[], EnvBase]],  # noqa: F821
-    create_env_kwargs: Dict[str, Any],
+    create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],  # noqa: F821
+    create_env_kwargs: dict[str, Any],
     policy: Callable[[TensorDictBase], TensorDictBase],
     max_frames_per_traj: int,
    frames_per_batch: int,
     reset_at_each_iter: bool,
-    storing_device: Optional[Union[torch.device, str, int]],
-    env_device: Optional[Union[torch.device, str, int]],
-    policy_device: Optional[Union[torch.device, str, int]],
+    storing_device: torch.device | str | int | None,
+    env_device: torch.device | str | int | None,
+    policy_device: torch.device | str | int | None,
     idx: int = 0,
     exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
     reset_when_done: bool = True,
diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py
index 3491c48138c..df0edc5d5d3 100644
--- a/torchrl/collectors/distributed/generic.py
+++ b/torchrl/collectors/distributed/generic.py
@@ -11,17 +11,16 @@
 import warnings
 from copy import copy, deepcopy
 from datetime import timedelta
-from typing import Callable, List, OrderedDict, Type
+from typing import Callable, List, OrderedDict

 import torch.cuda
 from tensordict import TensorDict
 from torch import nn
-
-from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE
+from torchrl._utils import VERBOSE, _ProcessNoWarn, logger as torchrl_logger
 from torchrl.collectors import MultiaSyncDataCollector
 from torchrl.collectors.collectors import (
-    DataCollectorBase,
     DEFAULT_EXPLORATION_TYPE,
+    DataCollectorBase,
     MultiSyncDataCollector,
     SyncDataCollector,
 )
@@ -410,17 +409,17 @@ def __init__(
         *,
         frames_per_batch: int,
         total_frames: int = -1,
-        device: torch.device | List[torch.device] = None,
-        storing_device: torch.device | List[torch.device] = None,
-        env_device: torch.device | List[torch.device] = None,
-        policy_device: torch.device | List[torch.device] = None,
+        device: torch.device | list[torch.device] = None,
+        storing_device: torch.device | list[torch.device] = None,
+        env_device: torch.device | list[torch.device] = None,
+        policy_device: torch.device | list[torch.device] = None,
         max_frames_per_traj: int = -1,
         init_random_frames: int = -1,
         reset_at_each_iter: bool = False,
         postproc: Callable | None = None,
         split_trajs: bool = False,
-        exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE,  # noqa
-        collector_class: Type = SyncDataCollector,
+        exploration_type: ExporationType = DEFAULT_EXPLORATION_TYPE,  # noqa
+        collector_class: type = SyncDataCollector,
         collector_kwargs: dict = None,
         num_workers_per_collector: int = 1,
         sync: bool = False,
@@ -527,19 +526,19 @@ def __init__(
         self._make_container()

     @property
-    def device(self) -> List[torch.device]:
+    def device(self) -> list[torch.device]:
         return self._device

     @property
-    def storing_device(self) -> List[torch.device]:
+    def storing_device(self) -> list[torch.device]:
         return self._storing_device

     @property
-    def env_device(self) -> List[torch.device]:
+    def env_device(self) -> list[torch.device]:
         return self._env_device

     @property
-    def policy_device(self) -> List[torch.device]:
+    def policy_device(self) -> list[torch.device]:
         return self._policy_device

     @device.setter
@@ -899,7 +898,7 @@ def update_policy_weights_(self, worker_rank=None) -> None:
     def set_seed(self, seed: int, static_seed: bool = False) -> int:
         for i in range(self.num_workers):
             rank = i + 1
-            self._store.set(f"NODE_{rank}_in", f"seeding_{seed}".encode("utf-8"))
+            self._store.set(f"NODE_{rank}_in", f"seeding_{seed}".encode())
             status = self._store.get(f"NODE_{rank}_out")
             if status != b"updated":
                 raise RuntimeError(f"Expected 'seeded' but got status {status}.")
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py
index 715e41f50fd..46947f51d5d 100644
--- a/torchrl/collectors/distributed/ray.py
+++ b/torchrl/collectors/distributed/ray.py
@@ -6,17 +6,16 @@
 from __future__ import annotations

 import warnings
-from typing import Callable, Dict, Iterator, List, OrderedDict, Union
+from typing import Callable, Iterator, List, OrderedDict

 import torch
 import torch.nn as nn
 from tensordict import TensorDict, TensorDictBase
-
 from torchrl._utils import logger as torchrl_logger
 from torchrl.collectors import MultiaSyncDataCollector
 from
torchrl.collectors.collectors import ( - DataCollectorBase, DEFAULT_EXPLORATION_TYPE, + DataCollectorBase, MultiSyncDataCollector, SyncDataCollector, ) @@ -24,7 +23,6 @@ from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator - RAY_ERR = None try: import ray @@ -289,15 +287,15 @@ class RayCollector(DataCollectorBase): def __init__( self, - create_env_fn: Union[Callable, EnvBase, List[Callable], List[EnvBase]], + create_env_fn: Callable | EnvBase | list[Callable] | list[EnvBase], policy: Callable[[TensorDict], TensorDict], *, frames_per_batch: int, total_frames: int = -1, - device: torch.device | List[torch.device] = None, - storing_device: torch.device | List[torch.device] = None, - env_device: torch.device | List[torch.device] = None, - policy_device: torch.device | List[torch.device] = None, + device: torch.device | list[torch.device] = None, + storing_device: torch.device | list[torch.device] = None, + env_device: torch.device | list[torch.device] = None, + policy_device: torch.device | list[torch.device] = None, max_frames_per_traj=-1, init_random_frames=-1, reset_at_each_iter=False, @@ -305,11 +303,11 @@ def __init__( split_trajs=False, exploration_type=DEFAULT_EXPLORATION_TYPE, collector_class: Callable[[TensorDict], TensorDict] = SyncDataCollector, - collector_kwargs: Union[Dict, List[Dict]] = None, + collector_kwargs: dict | list[dict] = None, num_workers_per_collector: int = 1, sync: bool = False, - ray_init_config: Dict = None, - remote_configs: Union[Dict, List[Dict]] = None, + ray_init_config: dict = None, + remote_configs: dict | list[dict] = None, num_collectors: int = None, update_after_each_batch=False, max_weight_update_interval=-1, @@ -483,19 +481,19 @@ def num_workers(self): return self.num_collectors @property - def device(self) -> List[torch.device]: + def device(self) -> list[torch.device]: return self._device @property - def storing_device(self) -> List[torch.device]: + def storing_device(self) -> list[torch.device]: return self._storing_device @property - def env_device(self) -> List[torch.device]: + def env_device(self) -> list[torch.device]: return self._env_device @property - def policy_device(self) -> List[torch.device]: + def policy_device(self) -> list[torch.device]: return self._policy_device @device.setter @@ -713,13 +711,13 @@ def update_policy_weights_(self, worker_rank=None) -> None: ) self._batches_since_weight_update[worker_rank - 1] = 0 - def set_seed(self, seed: int, static_seed: bool = False) -> List[int]: + def set_seed(self, seed: int, static_seed: bool = False) -> list[int]: """Calls parent method for each remote collector iteratively and returns final seed.""" for collector in self.remote_collectors(): seed = ray.get(object_refs=collector.set_seed.remote(seed, static_seed)) return seed - def state_dict(self) -> List[OrderedDict]: + def state_dict(self) -> list[OrderedDict]: """Calls parent method for each remote collector and returns a list of results.""" futures = [ collector.state_dict.remote() for collector in self.remote_collectors() @@ -727,9 +725,7 @@ def state_dict(self) -> List[OrderedDict]: results = ray.get(object_refs=futures) return results - def load_state_dict( - self, state_dict: Union[OrderedDict, List[OrderedDict]] - ) -> None: + def load_state_dict(self, state_dict: OrderedDict | list[OrderedDict]) -> None: """Calls parent method for each remote collector.""" if isinstance(state_dict, OrderedDict): state_dicts = [state_dict] diff --git a/torchrl/collectors/distributed/rpc.py 
b/torchrl/collectors/distributed/rpc.py index 94421ce8ca3..2575469a7bb 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -15,7 +15,6 @@ from typing import Callable, List, OrderedDict from torchrl._utils import logger as torchrl_logger - from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, @@ -266,16 +265,16 @@ def __init__( *, frames_per_batch: int, total_frames: int = -1, - device: torch.device | List[torch.device] = None, - storing_device: torch.device | List[torch.device] = None, - env_device: torch.device | List[torch.device] = None, - policy_device: torch.device | List[torch.device] = None, + device: torch.device | list[torch.device] = None, + storing_device: torch.device | list[torch.device] = None, + env_device: torch.device | list[torch.device] = None, + policy_device: torch.device | list[torch.device] = None, max_frames_per_traj: int = -1, init_random_frames: int = -1, reset_at_each_iter: bool = False, postproc: Callable | None = None, split_trajs: bool = False, - exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_type: ExporationType = DEFAULT_EXPLORATION_TYPE, # noqa collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -386,19 +385,19 @@ def __init__( self._init() @property - def device(self) -> List[torch.device]: + def device(self) -> list[torch.device]: return self._device @property - def storing_device(self) -> List[torch.device]: + def storing_device(self) -> list[torch.device]: return self._storing_device @property - def env_device(self) -> List[torch.device]: + def env_device(self) -> list[torch.device]: return self._env_device @property - def policy_device(self) -> List[torch.device]: + def policy_device(self) -> list[torch.device]: return self._policy_device @device.setter diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 6aa66dfbdd2..17c54279e4b 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -16,13 +16,11 @@ import torch.cuda from tensordict import TensorDict from torch import nn - -from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE - +from torchrl._utils import VERBOSE, _ProcessNoWarn, logger as torchrl_logger from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( - DataCollectorBase, DEFAULT_EXPLORATION_TYPE, + DataCollectorBase, MultiSyncDataCollector, SyncDataCollector, ) @@ -276,16 +274,16 @@ def __init__( *, frames_per_batch: int, total_frames: int = -1, - device: torch.device | List[torch.device] = None, - storing_device: torch.device | List[torch.device] = None, - env_device: torch.device | List[torch.device] = None, - policy_device: torch.device | List[torch.device] = None, + device: torch.device | list[torch.device] = None, + storing_device: torch.device | list[torch.device] = None, + env_device: torch.device | list[torch.device] = None, + policy_device: torch.device | list[torch.device] = None, max_frames_per_traj: int = -1, init_random_frames: int = -1, reset_at_each_iter: bool = False, postproc: Callable | None = None, split_trajs: bool = False, - exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_type: ExporationType = DEFAULT_EXPLORATION_TYPE, # noqa collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ 
-384,19 +382,19 @@ def __init__( self._make_container() @property - def device(self) -> List[torch.device]: + def device(self) -> list[torch.device]: return self._device @property - def storing_device(self) -> List[torch.device]: + def storing_device(self) -> list[torch.device]: return self._storing_device @property - def env_device(self) -> List[torch.device]: + def env_device(self) -> list[torch.device]: return self._env_device @property - def policy_device(self) -> List[torch.device]: + def policy_device(self) -> list[torch.device]: return self._policy_device @device.setter diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 2a1ec90feb6..775751d73af 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -22,7 +22,6 @@ from torch import multiprocessing as mp from torchrl._utils import logger as torchrl_logger from torchrl.data.datasets.common import BaseDatasetExperienceReplay - from torchrl.data.replay_buffers.samplers import ( SamplerWithoutReplacement, SliceSampler, @@ -403,7 +402,7 @@ def __init__( download: bool | str = True, sampler=None, writer=None, - transform: "Transform" | None = None, # noqa: F821 + transform: Transform | None = None, # noqa: F821 num_procs: int = 0, num_slices: int | None = None, slice_len: int | None = None, @@ -493,7 +492,7 @@ def _is_downloaded(self): if os.path.exists(self.dataset_path / "meta.json"): return True if os.path.exists(self.dataset_path / "processed.json"): - with open(self.dataset_path / "processed.json", "r") as jsonfile: + with open(self.dataset_path / "processed.json") as jsonfile: return json.load(jsonfile).get("processed", False) == self._max_runs return False @@ -514,7 +513,7 @@ def _download_and_preproc(self): command, shell=True, capture_output=True ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) files = [ - file.decode("utf-8").replace("$", "\$") # noqa: W605 + file.decode("utf-8").replace("$", r"\$") # noqa: W605 for file in output.stdout.splitlines() if file.endswith(b".gz") ] @@ -819,7 +818,7 @@ def _load_split(self, path): def _proc_td(self, td, index): td_data = td.get("data") - obs_ = td_data.get(("observation"))[index + 1] + obs_ = td_data.get("observation")[index + 1] done = td_data.get(("next", "terminated"))[index].squeeze(-1).bool() if done.ndim and done.any(): obs_ = torch.index_fill(obs_, 0, done.nonzero().squeeze(), 0) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 4ed9dcabada..e8b47619bce 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -10,22 +10,16 @@ import tempfile import urllib import warnings - from pathlib import Path from typing import Callable import numpy as np - import torch - -from tensordict import make_tensordict, PersistentTensorDict, TensorDict - +from tensordict import PersistentTensorDict, TensorDict, make_tensordict from torchrl._utils import logger as torchrl_logger - from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.common import BaseDatasetExperienceReplay from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS - from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import TensorStorage @@ -145,7 +139,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | 
None = None, # noqa-F821 split_trajs: bool = False, from_env: bool = False, use_truncated_as_done: bool = True, @@ -459,7 +453,7 @@ def _download_dataset_from_url(dataset_url, dataset_path): torchrl_logger.info(f"Downloading dataset: {dataset_url} to {dataset_filepath}") urllib.request.urlretrieve(dataset_url, dataset_filepath) if not os.path.exists(dataset_filepath): - raise IOError("Failed to download dataset from %s" % dataset_url) + raise OSError("Failed to download dataset from %s" % dataset_url) return dataset_filepath diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 53efeb54898..ebe9d032c0d 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -9,7 +9,6 @@ import os.path import shutil import tempfile - from collections import defaultdict from contextlib import nullcontext from dataclasses import asdict @@ -17,7 +16,6 @@ from typing import Callable import torch - from tensordict import PersistentTensorDict, TensorDict from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger from torchrl.data.datasets.common import BaseDatasetExperienceReplay @@ -167,7 +165,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, ): self.dataset_id = dataset_id @@ -381,7 +379,7 @@ def _load(self): return TensorDict.load_memmap(self.data_path) def _load_and_proc_metadata(self): - with open(self.metadata_path, "r") as file: + with open(self.metadata_path) as file: self.metadata = json.load(file) self.metadata["observation_space"] = _proc_spec( self.metadata["observation_space"] diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py index c420eb93ad3..344cf43bf3f 100644 --- a/torchrl/data/datasets/openml.py +++ b/torchrl/data/datasets/openml.py @@ -11,7 +11,6 @@ import numpy as np from tensordict import TensorDict from torchrl.data.datasets.common import BaseDatasetExperienceReplay - from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers import ( Sampler, @@ -67,7 +66,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "Transform" | None = None, # noqa-F821 + transform: Transform | None = None, # noqa-F821 ): if sampler is None: diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 94ad82c7a88..cf81f44cf56 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -11,13 +11,11 @@ import shutil import tempfile from pathlib import Path -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable import torch - -from tensordict import make_tensordict, NonTensorData, pad, TensorDict +from tensordict import NonTensorData, TensorDict, make_tensordict, pad from tensordict.utils import _is_non_tensor - from torchrl.data.datasets.common import BaseDatasetExperienceReplay from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer @@ -26,7 +24,7 @@ SliceSampler, SliceSamplerWithoutReplacement, ) -from torchrl.data.replay_buffers.storages import _collate_id, Storage, TensorStorage +from torchrl.data.replay_buffers.storages import Storage, TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, 
Writer _has_datasets = importlib.util.find_spec("datasets", None) is not None @@ -313,7 +311,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, strict_length: bool = True, ): @@ -656,7 +654,7 @@ def dumps(self, path): state_dict = self.state_dict() json.dump(state_dict, path / "state_dict.json") - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "repo": self.repo, "split": self.split, @@ -674,7 +672,7 @@ def loads(self, path): state_dict = json.load(path / "state_dict.json") self.load_state_dict(state_dict) - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: for key, val in state_dict.items(): setattr(self, key, val) self._init() @@ -722,7 +720,7 @@ class _StreamingSampler(Sampler): def __init__(self): ... - def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]: return range(batch_size), {} def _empty(self): @@ -734,10 +732,10 @@ def dumps(self, path): def loads(self, path): ... - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py index 1a83c302860..9eccf07a286 100644 --- a/torchrl/data/datasets/roboset.py +++ b/torchrl/data/datasets/roboset.py @@ -8,13 +8,11 @@ import os.path import shutil import tempfile - from contextlib import nullcontext from pathlib import Path from typing import Callable import torch - from tensordict import PersistentTensorDict, TensorDict from torchrl._utils import ( KeyDependentDefaultDict, @@ -162,7 +160,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, **env_kwargs, ): diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index daf1c85e3d7..0c851e2408c 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -5,7 +5,6 @@ from __future__ import annotations import functools - import importlib import json import os @@ -14,21 +13,18 @@ import tempfile from collections import defaultdict from pathlib import Path -from typing import Callable, List +from typing import Callable import numpy as np - import torch from tensordict import PersistentTensorDict, TensorDict from torch import multiprocessing as mp - from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger from torchrl.data.datasets.common import BaseDatasetExperienceReplay from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import TensorStorage from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer - from torchrl.envs.transforms import Compose, Resize, ToTensorImage from torchrl.envs.utils import _classproperty @@ -172,10 +168,10 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - 
transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, totensor: bool = True, - image_size: int | List[int] | None = None, + image_size: int | list[int] | None = None, num_workers: int = 0, **env_kwargs, ): @@ -388,7 +384,7 @@ def _available_datasets(cls): return [str(path)[6:] for path in sibs] except Exception: # return the default datasets - with open(THIS_DIR / "vd4rl.json", "r") as file: + with open(THIS_DIR / "vd4rl.json") as file: return json.load(file) def _make_split(self): diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index a3ae9ec1ae9..446e71f56fb 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Callable, List +from typing import Callable import torch from torch.nn import Module @@ -103,7 +103,7 @@ def __init__(self, as_tensor: bool = True): super().__init__() self.as_tensor = as_tensor - def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]: + def forward(self, x: torch.Tensor) -> torch.Tensor | list[bytes]: hash_values = [] if x.dtype in (torch.bfloat16,): x = x.to(torch.float16) diff --git a/torchrl/data/map/query.py b/torchrl/data/map/query.py index 3eca179cf56..d53e2087760 100644 --- a/torchrl/data/map/query.py +++ b/torchrl/data/map/query.py @@ -5,7 +5,7 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, TypeVar +from typing import Any, Callable, List, Mapping, TypeVar import torch import torch.nn as nn @@ -39,7 +39,7 @@ def __call__(self, key: torch.Tensor, extend: bool = False) -> torch.Tensor: ) return torch.tensor(result, device=key.device, dtype=key.dtype) - def state_dict(self) -> Dict[str, torch.Tensor]: + def state_dict(self) -> dict[str, torch.Tensor]: values = torch.tensor(self._index_to_index.values()) keys = torch.tensor(self._index_to_index.keys()) return {"keys": keys, "values": values} @@ -111,11 +111,11 @@ class QueryModule(TensorDictModuleBase): def __init__( self, - in_keys: List[NestedKey], + in_keys: list[NestedKey], index_key: NestedKey = "_index", hash_key: NestedKey = "_hash", *, - hash_module: Callable[[Any], int] | List[Callable[[Any], int]] | None = None, + hash_module: Callable[[Any], int] | list[Callable[[Any], int]] | None = None, hash_to_int: Callable[[int], int] | None = None, aggregator: Callable[[Any], int] = None, clone: bool = False, diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index 34d4bb8d0fa..f510f1d5dbe 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -10,15 +10,11 @@ from typing import Any, Callable, Dict, Generic, List, TypeVar import torch -from tensordict import is_tensor_collection, NestedKey, TensorDictBase +from tensordict import NestedKey, TensorDictBase, is_tensor_collection from tensordict.nn.common import TensorDictModuleBase from torchrl.data.map.hash import RandomProjectionHash, SipHash from torchrl.data.map.query import QueryModule -from torchrl.data.replay_buffers.storages import ( - _get_default_collate, - LazyTensorStorage, - TensorStorage, -) +from torchrl.data.replay_buffers.storages import (LazyTensorStorage, TensorStorage, _get_default_collate) K = TypeVar("K") V = TypeVar("V") @@ -117,9 +113,9 @@ def __init__( self, *, query_module: QueryModule, - storage: Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]], 
+ storage: dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]], collate_fn: Callable[[Any], Any] | None = None, - out_keys: List[NestedKey] | None = None, + out_keys: list[NestedKey] | None = None, write_fn: Callable[[Any, Any], Any] | None = None, ): super().__init__() @@ -143,7 +139,7 @@ def max_size(self): return self.storage.max_size @property - def out_keys(self) -> List[NestedKey]: + def out_keys(self) -> list[NestedKey]: out_keys = self.__dict__.get("_out_keys_and_lazy") if out_keys is not None: return out_keys[0] @@ -173,8 +169,8 @@ def from_tensordict_pair( cls, source, dest, - in_keys: List[NestedKey], - out_keys: List[NestedKey] | None = None, + in_keys: list[NestedKey], + out_keys: list[NestedKey] | None = None, max_size: int = 1000, storage_constructor: type | None = None, hash_module: Callable | None = None, diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 684c4f9901b..ca8fc683d7f 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -6,23 +6,14 @@ import weakref from collections import deque - -from typing import Any, Callable, Dict, List, Literal, Tuple +from typing import Any, Callable, Dict, Literal, Tuple import torch -from tensordict import ( - merge_tensordicts, - NestedKey, - TensorClass, - TensorDict, - TensorDictBase, - unravel_key, -) +from tensordict import (NestedKey, TensorClass, TensorDict, TensorDictBase, merge_tensordicts, unravel_key) from torchrl.data.map.tdstorage import TensorDictMap from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree from torchrl.data.replay_buffers.storages import ListStorage from torchrl.data.tensor_specs import Composite - from torchrl.envs.common import EnvBase @@ -88,10 +79,10 @@ class Tree(TensorClass["nocast"]): node_data: TensorDict | None = None # Stack of subtrees. A subtree is produced when an action is taken. - subtree: "Tree" = None + subtree: Tree = None # weakrefs to the parent(s) of the node - _parent: weakref.ref | List[weakref.ref] | None = None + _parent: weakref.ref | list[weakref.ref] | None = None # Specs: contains information such as action or observation keys and spaces. # If present, they should be structured like env specs are: @@ -389,7 +380,7 @@ def __contains__(self, other: Tree) -> bool: def vertices( self, *, key_type: Literal["id", "hash", "path"] = "hash" - ) -> Dict[int | Tuple[int], Tree]: + ) -> dict[int | tuple[int], Tree]: """Returns a map containing the vertices of the Tree. Keyword args: @@ -463,7 +454,7 @@ def num_vertices(self, *, count_repeat: bool = False) -> int: } ) - def edges(self) -> List[Tuple[int, int]]: + def edges(self) -> list[tuple[int, int]]: """Retrieves a list of edges in the tree. Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. @@ -530,7 +521,7 @@ def max_length(self): return lengths[0] return max(*lengths) - def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None: + def rollout_from_path(self, path: tuple[int]) -> TensorDictBase | None: """Retrieves the rollout data along a given path in the tree. The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. 
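The docstring above states that per-node rollout data is concatenated along the last dimension. A toy, self-contained illustration of that concatenation with plain TensorDicts (keys and shapes are invented; this is not the Tree implementation itself):

    # Toy illustration of concatenating rollout segments along dim=-1, as the
    # `rollout_from_path` docstring describes. Keys and shapes are invented.
    import torch
    from tensordict import TensorDict

    seg_a = TensorDict({"action": torch.zeros(3, 1)}, batch_size=[3])  # 3 steps
    seg_b = TensorDict({"action": torch.ones(2, 1)}, batch_size=[2])   # 2 steps

    rollout = torch.cat([seg_a, seg_b], dim=-1)  # dim=-1 is the only batch dim
    print(rollout.batch_size)       # torch.Size([5])
    print(rollout["action"].shape)  # torch.Size([5, 1])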
@@ -557,7 +548,7 @@ def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None: return torch.cat(rollouts, dim=-1) @staticmethod - def _label(info: List[str], tree: "Tree", root=False): + def _label(info: list[str], tree: Tree, root=False): labels = [] for key in info: if key == "hash": @@ -577,7 +568,7 @@ def plot( self: Tree, backend: str = "plotly", figure: str = "tree", - info: List[str] = None, + info: list[str] = None, make_labels: Callable[[Any, ...], Any] | None = None, ): """Plots a visualization of the tree using the specified backend and figure type. @@ -811,11 +802,11 @@ def __init__( data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, - done_keys: List[NestedKey] | None = None, - reward_keys: List[NestedKey] = None, - observation_keys: List[NestedKey] = None, - action_keys: List[NestedKey] = None, - excluded_keys: List[NestedKey] = None, + done_keys: list[NestedKey] | None = None, + reward_keys: list[NestedKey] = None, + observation_keys: list[NestedKey] = None, + action_keys: list[NestedKey] = None, + excluded_keys: list[NestedKey] = None, consolidated: bool | None = None, ): @@ -856,7 +847,7 @@ def __init__( self.consolidated = consolidated @property - def done_keys(self) -> List[NestedKey]: + def done_keys(self) -> list[NestedKey]: """Done Keys. Returns the keys used to indicate that an episode has ended. @@ -877,7 +868,7 @@ def done_keys(self, value): self._done_keys = _make_list_of_nestedkeys(value, "done_keys") @property - def reward_keys(self) -> List[NestedKey]: + def reward_keys(self) -> list[NestedKey]: """Reward Keys. Returns the keys used to retrieve rewards from the environment's output. @@ -897,7 +888,7 @@ def reward_keys(self, value): self._reward_keys = _make_list_of_nestedkeys(value, "reward_keys") @property - def action_keys(self) -> List[NestedKey]: + def action_keys(self) -> list[NestedKey]: """Action Keys. Returns the keys used to retrieve actions from the environment's input. @@ -917,7 +908,7 @@ def action_keys(self, value): self._action_keys = _make_list_of_nestedkeys(value, "action_keys") @property - def observation_keys(self) -> List[NestedKey]: + def observation_keys(self) -> list[NestedKey]: """Observation Keys. Returns the keys used to retrieve observations from the environment's output. @@ -936,7 +927,7 @@ def observation_keys(self, value): self._observation_keys = _make_list_of_nestedkeys(value, "observation_keys") @property - def excluded_keys(self) -> List[NestedKey] | None: + def excluded_keys(self) -> list[NestedKey] | None: return self._excluded_keys @excluded_keys.setter @@ -1223,7 +1214,7 @@ def _make_local_tree( root: TensorDictBase, index: torch.Tensor | None = None, compact: bool = True, - ) -> Tuple[Tree, torch.Tensor | None, torch.Tensor | None]: + ) -> tuple[Tree, torch.Tensor | None, torch.Tensor | None]: root = root.select(*self.node_map.in_keys) node_meta = None if root in self.node_map: @@ -1422,7 +1413,7 @@ def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict( return tree.to_string(node_format_fn) -def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]: +def _make_list_of_nestedkeys(obj: Any, attr: str) -> list[NestedKey]: if obj is None: return obj if isinstance(obj, (str, tuple)): diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py index d9588d79905..43bd8ea2832 100644 --- a/torchrl/data/map/utils.py +++ b/torchrl/data/map/utils.py @@ -4,13 +4,13 @@ # LICENSE file in the root directory of this source tree. 
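Related to the `_parent: weakref.ref | list[weakref.ref] | None` field on `Tree` above: holding parents weakly means a child never pins its ancestors in memory. A toy demonstration of the idea (invented `Node` class; immediate collection after `del` assumes CPython's reference counting):

    # Toy sketch of weakly-referenced parents, as in Tree._parent above.
    # Invented class; prompt collection after `del` assumes CPython.
    import weakref


    class Node:
        def __init__(self, parent=None):
            self._parent = weakref.ref(parent) if parent is not None else None

        @property
        def parent(self):
            return self._parent() if self._parent is not None else None


    root = Node()
    child = Node(parent=root)
    assert child.parent is root
    del root
    assert child.parent is None  # the weakref now dangles; parent was freed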
from __future__ import annotations -from typing import Callable, List +from typing import Callable from tensordict import NestedKey def _plot_plotly_tree( - tree: "Tree", make_labels: Callable[[Tree], str] | None = None # noqa: F821 + tree: Tree, make_labels: Callable[[Tree], str] | None = None # noqa: F821 ): import plotly.graph_objects as go from igraph import Graph @@ -78,7 +78,7 @@ def make_labels(tree, path, *args, **kwargs): fig.show() -def _plot_plotly_box(tree: "Tree", info: List[NestedKey] = None): # noqa: F821 +def _plot_plotly_box(tree: Tree, info: list[NestedKey] = None): # noqa: F821 import plotly.graph_objects as go if info is None: @@ -89,7 +89,7 @@ def _plot_plotly_box(tree: "Tree", info: List[NestedKey] = None): # noqa: F821 _tree = tree - def extend(tree: "Tree", parent): # noqa: F821 + def extend(tree: Tree, parent): # noqa: F821 children = tree.subtree if children is None: return diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 6b74834385e..4e06756efc2 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -11,23 +11,17 @@ import numpy as np import torch -from tensordict import ( - is_tensor_collection, - NonTensorData, - PersistentTensorDict, - TensorDict, -) +from tensordict import (NonTensorData, PersistentTensorDict, TensorDict, is_tensor_collection) from tensordict.memmap import MemoryMappedTensor from tensordict.utils import _STRDTYPE2DTYPE - from torchrl.data.replay_buffers.utils import ( - _save_pytree, Flat2TED, H5Combine, H5Split, Nested2TED, TED2Flat, TED2Nested, + _save_pytree, ) @@ -120,7 +114,7 @@ def dumps(self, storage, path): ) def loads(self, storage, path): - with open(path / "storage_metadata.json", "r") as file: + with open(path / "storage_metadata.json") as file: metadata = json.load(file) is_pytree = metadata["is_pytree"] _len = metadata["len"] diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 6e8e879b512..20e029fc535 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -13,10 +13,9 @@ import warnings from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from typing import Any, Callable, Sequence import numpy as np - import torch try: @@ -222,10 +221,10 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "Transform" | None = None, # noqa-F821 + transform: Transform | None = None, # noqa-F821 batch_size: int | None = None, dim_extend: int | None = None, - checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821 + checkpointer: StorageCheckpointerBase | None = None, # noqa: F821 generator: torch.Generator | None = None, shared: bool = False, compilable: bool = None, @@ -460,7 +459,7 @@ def __setitem__(self, index, value) -> None: self._storage[index] = value return - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "_storage": self._storage.state_dict(), "_sampler": self._sampler.state_dict(), @@ -472,7 +471,7 @@ def state_dict(self) -> Dict[str, Any]: else None, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._storage.load_state_dict(state_dict["_storage"]) self._sampler.load_state_dict(state_dict["_sampler"]) 
self._writer.load_state_dict(state_dict["_writer"]) @@ -564,7 +563,7 @@ def loads(self, path): # fall back on state_dict for transforms if (path / "transform.t").exists(): self._transform.load_state_dict(torch.load(path / "transform.t")) - with open(path / "buffer_metadata.json", "r") as file: + with open(path / "buffer_metadata.json") as file: metadata = json.load(file) self._batch_size = metadata["batch_size"] @@ -662,8 +661,8 @@ def extend(self, data: Sequence) -> torch.Tensor: def update_priority( self, - index: Union[int, torch.Tensor, Tuple[torch.Tensor]], - priority: Union[int, torch.Tensor], + index: int | torch.Tensor | tuple[torch.Tensor], + priority: int | torch.Tensor, ) -> None: if isinstance(index, tuple): index = torch.stack(index, -1) @@ -675,7 +674,7 @@ def update_priority( self._sampler.update_priority(index, priority, storage=self.storage) @pin_memory_output - def _sample(self, batch_size: int) -> Tuple[Any, dict]: + def _sample(self, batch_size: int) -> tuple[Any, dict]: with self._replay_lock if not is_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index @@ -755,11 +754,11 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An return out, info return result[0] - def mark_update(self, index: Union[int, torch.Tensor]) -> None: + def mark_update(self, index: int | torch.Tensor) -> None: self._sampler.mark_update(index, storage=self._storage) def append_transform( - self, transform: "Transform", *, invert: bool = False # noqa-F821 + self, transform: Transform, *, invert: bool = False # noqa-F821 ) -> ReplayBuffer: # noqa: D417 """Appends transform at the end. @@ -796,7 +795,7 @@ def append_transform( def insert_transform( self, index: int, - transform: "Transform", # noqa-F821 + transform: Transform, # noqa-F821 *, invert: bool = False, ) -> ReplayBuffer: # noqa: D417 @@ -832,7 +831,7 @@ def __iter__(self): ): yield self.sample() - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() if self._rng is not None: rng_state = TensorDict( @@ -848,7 +847,7 @@ def __getstate__(self) -> Dict[str, Any]: state["_futures_lock_placeholder"] = None return state - def __setstate__(self, state: Dict[str, Any]): + def __setstate__(self, state: dict[str, Any]): rngstate = None if "_rng" in state: rngstate = state["_rng"] @@ -1008,14 +1007,14 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "Transform" | None = None, # noqa-F821 + transform: Transform | None = None, # noqa-F821 batch_size: int | None = None, dim_extend: int | None = None, ) -> None: if storage is None: storage = ListStorage(max_size=1_000) sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype) - super(PrioritizedReplayBuffer, self).__init__( + super().__init__( storage=storage, sampler=sampler, collate_fn=collate_fn, @@ -1355,7 +1354,7 @@ def sample( return data @pin_memory_output - def _sample(self, batch_size: int) -> Tuple[Any, dict]: + def _sample(self, batch_size: int) -> tuple[Any, dict]: with self._replay_lock if not is_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index @@ -1523,7 +1522,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "Transform" | None = None, # noqa-F821 + transform: Transform | None = None, # noqa-F821 
reduction: str = "max", batch_size: int | None = None, dim_extend: int | None = None, @@ -1536,7 +1535,7 @@ def __init__( sampler = PrioritizedSampler( storage.max_size, alpha, beta, eps, reduction=reduction ) - super(TensorDictPrioritizedReplayBuffer, self).__init__( + super().__init__( priority_key=priority_key, storage=storage, sampler=sampler, @@ -1572,11 +1571,11 @@ def sample( def add(self, data: TensorDictBase) -> int: return super().add(data) - def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: + def extend(self, tensordicts: list | TensorDictBase) -> torch.Tensor: return super().extend(tensordicts) def update_priority( - self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] + self, index: int | torch.Tensor, priority: int | torch.Tensor ) -> None: return super().update_priority(index, priority) @@ -1593,7 +1592,7 @@ def __init__(self, device: DEVICE_TYPING | None = None): ) -def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]: +def stack_tensors(list_of_tensor_iterators: list) -> tuple[torch.Tensor]: """Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together. Args: @@ -1765,10 +1764,10 @@ def __init__( storages: StorageEnsemble | None = None, samplers: SamplerEnsemble | None = None, writers: WriterEnsemble | None = None, - transform: "Transform" | None = None, # noqa: F821 + transform: Transform | None = None, # noqa: F821 batch_size: int | None = None, collate_fn: Callable | None = None, - collate_fns: List[Callable] | None = None, + collate_fns: list[Callable] | None = None, p: Tensor = None, sample_from_all: bool = False, num_buffer_sampled: int | None = None, @@ -1849,7 +1848,7 @@ def _collate_fn(self, value): _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead." def __getitem__( - self, index: Union[int, torch.Tensor, Tuple, np.ndarray, List, slice, Ellipsis] + self, index: int | torch.Tensor | tuple | np.ndarray | list | slice | Ellipsis ) -> Any: # accepts inputs: # (int | 1d tensor | 1d list | 1d array | slice | ellipsis | range, int | tensor | list | array | slice | ellipsis | range) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 911280eb667..19db2fa9431 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -12,16 +12,13 @@ from copy import copy, deepcopy from multiprocessing.context import get_spawning_popen from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Any import numpy as np import torch - from tensordict import MemoryMappedTensor, TensorDict from tensordict.utils import NestedKey - from torchrl._extension import EXTENSION_WARNING - from torchrl._utils import _replace_last, logger from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index @@ -50,7 +47,7 @@ class Sampler(ABC): _rng: torch.Generator | None = None @abstractmethod - def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]: ... 
def add(self, index: int) -> None: @@ -61,8 +58,8 @@ def extend(self, index: torch.Tensor) -> None: def update_priority( self, - index: Union[int, torch.Tensor], - priority: Union[float, torch.Tensor], + index: int | torch.Tensor, + priority: float | torch.Tensor, *, storage: Storage | None = None, ) -> dict | None: @@ -72,7 +69,7 @@ def update_priority( return def mark_update( - self, index: Union[int, torch.Tensor], *, storage: Storage | None = None + self, index: int | torch.Tensor, *, storage: Storage | None = None ) -> None: return @@ -81,11 +78,11 @@ def default_priority(self) -> float: return 1.0 @abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: ... @abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... @property @@ -123,7 +120,7 @@ class RandomSampler(Sampler): """ - def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]: if len(storage) == 0: raise RuntimeError(_EMPTY_STORAGE_ERROR) index = storage._rand_given_ndim(batch_size) @@ -140,10 +137,10 @@ def loads(self, path): # no op ... - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: return @@ -232,7 +229,7 @@ def _storage_len(self, storage): def sample( self, storage: Storage, batch_size: int - ) -> Tuple[Any, dict]: # noqa: F811 + ) -> tuple[Any, dict]: # noqa: F811 len_storage = self._storage_len(storage) if len_storage == 0: raise RuntimeError(_EMPTY_STORAGE_ERROR) @@ -269,7 +266,7 @@ def _empty(self): self.len_storage = 0 self._ran_out = False - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return OrderedDict( len_storage=self.len_storage, _sample_list=self._sample_list, @@ -277,7 +274,7 @@ def state_dict(self) -> Dict[str, Any]: _ran_out=self._ran_out, ) - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.len_storage = state_dict["len_storage"] self._sample_list = state_dict["_sample_list"] self.drop_last = state_dict["drop_last"] @@ -542,8 +539,8 @@ def extend(self, index: torch.Tensor | tuple) -> None: @torch.no_grad() def update_priority( self, - index: Union[int, torch.Tensor], - priority: Union[float, torch.Tensor], + index: int | torch.Tensor, + priority: float | torch.Tensor, *, storage: TensorStorage | None = None, ) -> None: # noqa: D417 @@ -626,11 +623,11 @@ def update_priority( self._max_priority = (maxval, maxidx) def mark_update( - self, index: Union[int, torch.Tensor], *, storage: Storage | None = None + self, index: int | torch.Tensor, *, storage: Storage | None = None ) -> None: self.update_priority(index, self.default_priority, storage=storage) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "_alpha": self._alpha, "_beta": self._beta, @@ -640,7 +637,7 @@ def state_dict(self) -> Dict[str, Any]: "_min_tree": deepcopy(self._min_tree), } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._alpha = state_dict["_alpha"] self._beta = state_dict["_beta"] self._eps = state_dict["_eps"] @@ -693,7 +690,7 @@ def dumps(self, path): def loads(self, path): path 
= Path(path).absolute() - with open(path / "sampler_metadata.json", "r") as file: + with open(path / "sampler_metadata.json") as file: metadata = json.load(file) self._alpha = metadata["_alpha"] self._beta = metadata["_beta"] @@ -992,7 +989,7 @@ def __init__( truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, compile: bool | dict = False, - span: bool | int | Tuple[bool | int, bool | int] = False, + span: bool | int | tuple[bool | int, bool | int] = False, use_gpu: torch.device | bool = False, ): self.num_slices = num_slices @@ -1324,7 +1321,7 @@ def _adjusted_batch_size(self, batch_size): num_slices = batch_size // self.slice_len return seq_length, num_slices - def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]: if self._batch_size_multiplier is not None: batch_size = batch_size * self._batch_size_multiplier # pick up as many trajs as we need @@ -1361,7 +1358,7 @@ def _sample_slices( traj_idx: torch.Tensor | None = None, *, storage, - ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]: + ) -> tuple[tuple[torch.Tensor, ...], dict[str, Any]]: # start_idx and stop_idx are 2d tensors organized like a non-zero def get_traj_idx(maxval): @@ -1442,7 +1439,7 @@ def _get_index( traj_idx: torch.Tensor | None = None, *, storage, - ) -> Tuple[torch.Tensor, dict]: + ) -> tuple[torch.Tensor, dict]: # end_point is the last possible index for start last_indexable_start = lengths[traj_idx] - seq_length + 1 if not self.span[1]: @@ -1555,10 +1552,10 @@ def loads(self, path): # no op ... - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... 
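Three rewrites that recur in this file and earlier ones are pure no-ops: they drop arguments that merely restate a default. Each equivalence can be checked directly (the temp file below is invented for the demo):

    # The three recurring no-op rewrites, each checked for equivalence.
    import tempfile

    # 1. open(): "r" is already the default mode.
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fh:
        fh.write("{}")
        path = fh.name
    # with open(path, "r") as file:   # old spelling
    with open(path) as file:          # new spelling, identical behavior
        assert file.read() == "{}"

    # 2. str.encode(): utf-8 is already the default codec.
    assert "seeding_42".encode() == "seeding_42".encode("utf-8")

    # 3. super(): the zero-argument form resolves class and instance itself.
    class Parent:
        def __init__(self):
            self.ready = True

    class Child(Parent):
        def __init__(self):
            # super(Child, self).__init__()  # old Python-2-compatible form
            super().__init__()

    assert Child().ready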
@@ -1787,7 +1784,7 @@ def _storage_len(self, storage): def sample( self, storage: Storage, batch_size: int - ) -> Tuple[Tuple[torch.Tensor, ...], dict]: + ) -> tuple[tuple[torch.Tensor, ...], dict]: if self._batch_size_multiplier is not None: batch_size = batch_size * self._batch_size_multiplier start_idx, stop_idx, lengths = self._get_stop_and_length(storage) @@ -1827,10 +1824,10 @@ def tuple_to_tensor(traj_idx, lengths=lengths): ) return idx, info - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return SamplerWithoutReplacement.state_dict(self) - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: return SamplerWithoutReplacement.load_state_dict(self, state_dict) @@ -1985,7 +1982,7 @@ def __init__( truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, compile: bool | dict = False, - span: bool | int | Tuple[bool | int, bool | int] = False, + span: bool | int | tuple[bool | int, bool | int] = False, max_priority_within_buffer: bool = False, ): SliceSampler.__init__( @@ -2045,7 +2042,7 @@ def __getstate__(self): return state def mark_update( - self, index: Union[int, torch.Tensor], *, storage: Storage | None = None + self, index: int | torch.Tensor, *, storage: Storage | None = None ) -> None: return PrioritizedSampler.mark_update(self, index, storage=storage) @@ -2111,7 +2108,7 @@ def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx): self._cache["preceding_stop_idx"] = preceding_stop_idx return preceding_stop_idx - def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]: # Sample `batch_size` indices representing the start of a slice. # The sampling is based on a weight vector. start_idx, stop_idx, lengths = self._get_stop_and_length(storage) @@ -2388,13 +2385,13 @@ def loads(self, path: Path): for i, sampler in enumerate(self._samplers): sampler.loads(path / str(i)) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: state_dict = OrderedDict() for i, sampler in enumerate(self._samplers): state_dict[str(i)] = sampler.state_dict() return state_dict - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: for i, sampler in enumerate(self._samplers): sampler.load_state_dict(state_dict[str(i)]) diff --git a/torchrl/data/replay_buffers/scheduler.py b/torchrl/data/replay_buffers/scheduler.py index 6829424c620..4c031cd6082 100644 --- a/torchrl/data/replay_buffers/scheduler.py +++ b/torchrl/data/replay_buffers/scheduler.py @@ -5,13 +5,10 @@ from __future__ import annotations from abc import ABC, abstractmethod - -from typing import Any, Callable, Dict +from typing import Any, Callable import numpy as np - import torch - from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler @@ -69,7 +66,7 @@ def state_dict(self): del sd["sampler"] return sd - def load_state_dict(self, state_dict: Dict[str, Any]): + def load_state_dict(self, state_dict: dict[str, Any]): """Load the scheduler's state. 
Args: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 89d7a4dbe43..bdbeb983c2c 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -5,7 +5,6 @@ from __future__ import annotations import abc - import logging import os import textwrap @@ -13,17 +12,12 @@ from collections import OrderedDict from copy import copy from multiprocessing.context import get_spawning_popen -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Sequence import numpy as np import tensordict import torch -from tensordict import ( - is_tensor_collection, - LazyStackedTensorDict, - TensorDict, - TensorDictBase, -) +from tensordict import (LazyStackedTensorDict, TensorDict, TensorDictBase, is_tensor_collection) from tensordict.base import _NESTED_TENSORS_AS_LISTS from tensordict.memmap import MemoryMappedTensor from tensordict.utils import _zip_strict @@ -36,12 +30,7 @@ StorageEnsembleCheckpointer, TensorStorageCheckpointer, ) -from torchrl.data.replay_buffers.utils import ( - _init_pytree, - _is_int, - INT_CLASSES, - tree_iter, -) +from torchrl.data.replay_buffers.utils import (INT_CLASSES, _init_pytree, _is_int, tree_iter) class Storage: @@ -86,7 +75,7 @@ def _is_full(self): return len(self) == self.max_size @property - def _attached_entities(self) -> List: + def _attached_entities(self) -> list: # RBs that use a given instance of Storage should add # themselves to this set. _attached_entities_list = getattr(self, "_attached_entities_list", None) @@ -142,11 +131,11 @@ def __len__(self): ... @abc.abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: ... @abc.abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... 
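One pattern worth spelling out in import hunks like the one above (`from typing import Any, Dict, List, Sequence, Union` shrinking to `from typing import Any, Sequence`): only names with import-free replacements disappear. `Dict`, `List`, `Tuple` and `Type` map onto builtins, and `Union`/`Optional` onto the `|` operator, whereas `Sequence`, `Callable` or `Iterator` would have to be re-imported from `collections.abc`; presumably to keep the change mechanical, their `typing` imports survive. Mixing the two styles is valid, as this small invented example shows:

    # Invented example mixing a surviving typing import with the new builtin
    # generics, exactly as the rewritten modules do.
    from __future__ import annotations

    from typing import Any, Sequence


    def as_counts(items: Sequence[Any]) -> dict[Any, int]:
        counts: dict[Any, int] = {}  # local annotations are never evaluated
        for item in items:
            counts[item] = counts.get(item, 0) + 1
        return counts


    print(as_counts("abca"))  # {'a': 2, 'b': 1, 'c': 1}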
@abc.abstractmethod @@ -241,7 +230,7 @@ def __init__(self, max_size: int | None = None, compilable: bool = False): def set( self, - cursor: Union[int, Sequence[int], slice], + cursor: int | Sequence[int] | slice, data: Any, *, set_cursor: bool = True, @@ -294,7 +283,7 @@ def set( else: self._storage[cursor] = data - def get(self, index: Union[int, Sequence[int], slice]) -> Any: + def get(self, index: int | Sequence[int] | slice) -> Any: if isinstance(index, (INT_CLASSES, slice)): return self._storage[index] elif isinstance(index, tuple): @@ -311,7 +300,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any: def __len__(self): return len(self._storage) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "_storage": [ elt if not hasattr(elt, "state_dict") else elt.state_dict() @@ -421,7 +410,7 @@ def __init__( super().__init__(max_size=max_size, compilable=compilable) self.stack_dim = stack_dim - def get(self, index: Union[int, Sequence[int], slice]) -> Any: + def get(self, index: int | Sequence[int] | slice) -> Any: out = super().get(index=index) if isinstance(out, list): stack_dim = self.stack_dim @@ -720,7 +709,7 @@ def __setstate__(self, state): state["_len_value"] = _len_value self.__dict__.update(state) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: _storage = self._storage if isinstance(_storage, torch.Tensor): pass @@ -794,8 +783,8 @@ def _get_new_len(self, data, cursor): @implement_for("torch", "2.0", None, compilable=True) def set( self, - cursor: Union[int, Sequence[int], slice], - data: Union[TensorDictBase, torch.Tensor], + cursor: int | Sequence[int] | slice, + data: TensorDictBase | torch.Tensor, *, set_cursor: bool = True, ): @@ -836,8 +825,8 @@ def set( @implement_for("torch", None, "2.0", compilable=True) def set( # noqa: F811 self, - cursor: Union[int, Sequence[int], slice], - data: Union[TensorDictBase, torch.Tensor], + cursor: int | Sequence[int] | slice, + data: TensorDictBase | torch.Tensor, *, set_cursor: bool = True, ): @@ -888,7 +877,7 @@ def set( # noqa: F811 ) self._storage[cursor] = data - def get(self, index: Union[int, Sequence[int], slice]) -> Any: + def get(self, index: int | Sequence[int] | slice) -> Any: _storage = self._storage is_tc = is_tensor_collection(_storage) if not self.initialized: @@ -1062,7 +1051,7 @@ def __init__( def _init( self, - data: Union[TensorDictBase, torch.Tensor, "PyTree"], # noqa: F821 + data: TensorDictBase | torch.Tensor | PyTree, # noqa: F821 ) -> None: if not self._compilable: # TODO: Investigate why this seems to have a performance impact with @@ -1225,7 +1214,7 @@ def __init__( ) self._len = 0 - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: _storage = self._storage if isinstance(_storage, torch.Tensor): _storage = _mem_map_tensor_as_tensor(_storage) @@ -1282,7 +1271,7 @@ def load_state_dict(self, state_dict): self.initialized = state_dict["initialized"] self._len = state_dict["_len"] - def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: + def _init(self, data: TensorDictBase | torch.Tensor) -> None: torchrl_logger.debug("Creating a MemmapStorage...") if self.device == "auto": self.device = data.device @@ -1324,7 +1313,7 @@ def max_size_along_dim0(data_shape): self._storage = out self.initialized = True - def get(self, index: Union[int, Sequence[int], slice]) -> Any: + def get(self, index: int | Sequence[int] | slice) -> Any: result = super().get(index) return result @@ -1357,7 +1346,7 @@ 
class StorageEnsemble(Storage): def __init__( self, *storages: Storage, - transforms: List["Transform"] = None, # noqa: F821 + transforms: list[Transform] = None, # noqa: F821 ): self._rng_private = None self._storages = storages @@ -1408,10 +1397,10 @@ def _convert_id(self, sub): def _get_storage(self, sub): return self._storages[sub] - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: raise NotImplementedError - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise NotImplementedError _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead." diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index 1e8985537f3..21c414e19bf 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -7,7 +7,6 @@ import contextlib import itertools - import math import operator import os @@ -93,7 +92,7 @@ def _pin_memory(output: Any) -> Any: def _reduce( tensor: torch.Tensor, reduction: str, dim: int | None = None -) -> Union[float, torch.Tensor]: +) -> float | torch.Tensor: """Reduces a tensor given the reduction method.""" if reduction == "max": result = tensor.max(dim=dim) @@ -977,15 +976,13 @@ def _roll_inplace(tensor, shift, out, index_dest=None, index_source=None): # Copy-paste of unravel-index for PT 2.0 def _unravel_index( - indices: Tensor, shape: Union[int, typing.Sequence[int], torch.Size] -) -> typing.Tuple[Tensor, ...]: + indices: Tensor, shape: int | typing.Sequence[int] | torch.Size +) -> tuple[Tensor, ...]: res_tensor = _unravel_index_impl(indices, shape) return res_tensor.unbind(-1) -def _unravel_index_impl( - indices: Tensor, shape: Union[int, typing.Sequence[int]] -) -> Tensor: +def _unravel_index_impl(indices: Tensor, shape: int | typing.Sequence[int]) -> Tensor: if isinstance(shape, (int, torch.SymInt)): shape = torch.Size([shape]) else: diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index c7043f8829b..adb82d54b2e 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -11,12 +11,11 @@ from copy import copy from multiprocessing.context import get_spawning_popen from pathlib import Path -from typing import Any, Dict, Sequence +from typing import Any, Sequence import numpy as np import torch - -from tensordict import is_tensor_collection, MemoryMappedTensor, TensorDictBase +from tensordict import MemoryMappedTensor, TensorDictBase, is_tensor_collection from tensordict.utils import _STRDTYPE2DTYPE, expand_as_right, is_tensorclass from torch import multiprocessing as mp @@ -70,11 +69,11 @@ def loads(self, path): ... @abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: ... @abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... def _replicate_index(self, index): @@ -131,10 +130,10 @@ def dumps(self, path): def loads(self, path): ... 
- def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: return @@ -160,7 +159,7 @@ def dumps(self, path): def loads(self, path): path = Path(path).absolute() - with open(path / "metadata.json", "r") as file: + with open(path / "metadata.json") as file: metadata = json.load(file) self._cursor = metadata["cursor"] @@ -209,10 +208,10 @@ def extend(self, data: Sequence) -> torch.Tensor: ent.mark_update(index) return index - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"_cursor": self._cursor} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._cursor = state_dict["_cursor"] def _empty(self): @@ -611,7 +610,7 @@ def dumps(self, path): def loads(self, path): path = Path(path).absolute() - with open(path / "metadata.json", "r") as file: + with open(path / "metadata.json") as file: metadata = json.load(file) self._cursor = metadata["cursor"] self._rank_key = metadata["rank_key"] @@ -623,10 +622,10 @@ def loads(self, path): shape=shape, ).tolist() - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: raise NotImplementedError - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise NotImplementedError def __repr__(self): @@ -731,8 +730,8 @@ def __repr__(self): writers = textwrap.indent(f"writers={self._writers}", " " * 4) return f"WriterEnsemble(\n{writers})" - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: raise NotImplementedError - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise NotImplementedError diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index a0905c2d063..346a6046566 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -7,13 +7,10 @@ import importlib.util import os from pathlib import Path - -from typing import Sequence, Type +from typing import Sequence import torch - from tensordict import TensorDict, TensorDictBase - from tensordict.utils import NestedKey from torchrl._utils import logger as torchrl_logger from torchrl.data.replay_buffers import ( @@ -94,7 +91,7 @@ def __init__( split, max_length, dataset_name, - tokenizer_fn: Type[TensorDictTokenizer], + tokenizer_fn: type[TensorDictTokenizer], pre_tokenization_hook=None, root_dir=None, from_disk=False, @@ -227,7 +224,7 @@ def _tokenize( @staticmethod def dataset_to_tensordict( - dataset: "datasets.Dataset" | TensorDict, # noqa: F821 + dataset: datasets.Dataset | TensorDict, # noqa: F821 data_dir: Path, prefix: NestedKey = None, features: Sequence[str] = None, @@ -320,7 +317,7 @@ def create_infinite_iterator(iterator): def get_dataloader( batch_size: int, block_size: int, - tensorclass_type: Type, + tensorclass_type: type, device: torch.device, dataset_name: str | None = None, infinite: bool = True, diff --git a/torchrl/data/rlhf/prompt.py b/torchrl/data/rlhf/prompt.py index 6f41fe48698..3b9755db69a 100644 --- a/torchrl/data/rlhf/prompt.py +++ b/torchrl/data/rlhf/prompt.py @@ -4,11 +4,8 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations -from typing import Optional - import torch -from tensordict import tensorclass, TensorDict - +from tensordict import TensorDict, tensorclass from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader DEFAULT_DATASET = "CarperAI/openai_summarize_tldr" @@ -21,9 +18,9 @@ class PromptData: input_ids: torch.Tensor attention_mask: torch.Tensor prompt_rindex: torch.Tensor - labels: Optional[torch.Tensor] = None - logits: Optional[torch.Tensor] = None - loss: Optional[torch.Tensor] = None + labels: torch.Tensor | None = None + logits: torch.Tensor | None = None + loss: torch.Tensor | None = None def mask_label(self, pad_token_id=50256): _, block_size = self.input_ids.shape diff --git a/torchrl/data/rlhf/reward.py b/torchrl/data/rlhf/reward.py index 98976984e27..12e2deef2b2 100644 --- a/torchrl/data/rlhf/reward.py +++ b/torchrl/data/rlhf/reward.py @@ -5,12 +5,9 @@ from __future__ import annotations import importlib -from typing import Optional import torch - from tensordict import tensorclass - from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader DEFAULT_DATASET = "CarperAI/openai_summarize_comparisons" @@ -24,8 +21,8 @@ class RewardData: input_ids: torch.Tensor attention_mask: torch.Tensor - rewards: Optional[torch.Tensor] = None - end_scores: Optional[torch.Tensor] = None + rewards: torch.Tensor | None = None + end_scores: torch.Tensor | None = None @tensorclass diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index bbde6761f4a..01d18ecec11 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -7,15 +7,12 @@ import abc import collections import importlib -from typing import List, Tuple import numpy as np import torch - from tensordict import TensorDict -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import functional as F - from torchrl.data.rlhf.prompt import PromptData _has_transformers = importlib.util.find_spec("transformers") is not None @@ -30,7 +27,7 @@ class KLControllerBase(abc.ABC): """ @abc.abstractmethod - def update(self, kl_values: List[float]) -> float: + def update(self, kl_values: list[float]) -> float: ... @@ -63,7 +60,7 @@ def __init__( if model is not None: self.model.kl_coef = self.coef - def update(self, kl_values: List[float] = None) -> float: + def update(self, kl_values: list[float] = None) -> float: if self.model is not None: self.model.kl_coef = self.coef return self.coef @@ -104,7 +101,7 @@ def __init__( if model is not None: self.model.kl_coef = self.coef - def update(self, kl_values: List[float]): + def update(self, kl_values: list[float]): """Update ``self.coef`` adaptively. 
Arguments: @@ -422,7 +419,7 @@ def _default_conf(self): ) def _get_scores( - self, scores: Tuple, generated_tokens: Tensor = None, use_max=False, pad_to=None + self, scores: tuple, generated_tokens: Tensor = None, use_max=False, pad_to=None ): scores = torch.stack(scores, 1) if scores.shape[1] != self.max_new_tokens: diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index d43cbd7810d..1d3777eb48d 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -10,9 +10,7 @@ import numpy as np import torch - from torch import Tensor - from torchrl.data.tensor_specs import ( Binary, Categorical, @@ -139,7 +137,7 @@ def consolidate_spec( return spec -def _empty_like_spec(specs: List[TensorSpec], shape): +def _empty_like_spec(specs: list[TensorSpec], shape): for spec in specs[1:]: if spec.__class__ != specs[0].__class__: raise ValueError( @@ -224,7 +222,7 @@ def contains_lazy_spec(spec: TensorSpec) -> bool: return False -class CloudpickleWrapper(object): +class CloudpickleWrapper: """A wrapper for functions that allow for serialization in multiprocessed settings.""" def __init__(self, fn: Callable, **kwargs): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 8be0d1c9a85..0969c694058 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -6,9 +6,7 @@ from __future__ import annotations import functools - import gc - import os import time import weakref @@ -17,39 +15,30 @@ from functools import wraps from multiprocessing import connection from multiprocessing.synchronize import Lock as MpLock -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Sequence from warnings import warn import torch - -from tensordict import ( - is_tensor_collection, - LazyStackedTensorDict, - TensorDict, - TensorDictBase, - unravel_key, -) +from tensordict import (LazyStackedTensorDict, TensorDict, TensorDictBase, is_tensor_collection, unravel_key) from tensordict.base import _is_leaf_nontensor from tensordict.utils import _zip_strict from torch import multiprocessing as mp from torchrl._utils import ( + VERBOSE, + _ProcessNoWarn, _check_for_faulty_process, _make_ordinal_device, - _ProcessNoWarn, logger as torchrl_logger, - VERBOSE, ) from torchrl.data.tensor_specs import Composite, NonTensor -from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING -from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData +from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING, contains_lazy_spec +from torchrl.envs.common import EnvBase, EnvMetaData, _EnvPostInit, _do_nothing from torchrl.envs.env_creator import get_env_metadata - # legacy from torchrl.envs.libs.envpool import ( # noqa: F401 MultiThreadedEnv, MultiThreadedEnvWrapper, ) - from torchrl.envs.utils import ( _aggregate_end_of_traj, _sort_keys, @@ -100,7 +89,7 @@ def __iter__(self): class _dispatch_caller_serial: - def __init__(self, list_callable: List[Callable, Any]): + def __init__(self, list_callable: list[Callable, Any]): self.list_callable = list_callable def __call__(self, *args, **kwargs): @@ -296,15 +285,15 @@ class BatchedEnvBase(EnvBase): def __init__( self, num_workers: int, - create_env_fn: Union[Callable[[], EnvBase], Sequence[Callable[[], EnvBase]]], + create_env_fn: Callable[[], EnvBase] | Sequence[Callable[[], EnvBase]], *, - create_env_kwargs: Union[dict, Sequence[dict]] = None, + create_env_kwargs: dict | Sequence[dict] = None, pin_memory: bool = False, - 
share_individual_td: Optional[bool] = None, + share_individual_td: bool | None = None, shared_memory: bool = True, memmap: bool = False, - policy_proof: Optional[Callable] = None, - device: Optional[DEVICE_TYPING] = None, + policy_proof: Callable | None = None, + device: DEVICE_TYPING | None = None, allow_step_when_done: bool = False, num_threads: int = None, num_sub_threads: int = 1, @@ -492,7 +481,7 @@ def _has_dynamic_specs(self): return not self._use_buffers def _get_metadata( - self, create_env_fn: List[Callable], create_env_kwargs: List[Dict] + self, create_env_fn: list[Callable], create_env_kwargs: list[dict] ): if self._single_task: # if EnvCreator, the metadata are already there @@ -514,7 +503,7 @@ def _get_metadata( self.share_individual_td = False else: n_tasks = len(create_env_fn) - self.meta_data: List[EnvMetaData] = [] + self.meta_data: list[EnvMetaData] = [] for i in range(n_tasks): self.meta_data.append( get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone() @@ -541,7 +530,7 @@ def _get_metadata( self._set_properties() - def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None: + def update_kwargs(self, kwargs: dict | list[dict]) -> None: """Updates the kwargs of each environment given a dictionary or a list of dictionaries. Args: @@ -873,7 +862,7 @@ def close(self, *, raise_if_closed: bool = True) -> None: def _shutdown_workers(self) -> None: raise NotImplementedError - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): """This method is not used in batched envs.""" pass @@ -976,8 +965,8 @@ def _shutdown_workers(self) -> None: @_check_start def set_seed( - self, seed: Optional[int] = None, static_seed: bool = False - ) -> Optional[int]: + self, seed: int | None = None, static_seed: bool = False + ) -> int | None: for env in self._envs: new_seed = env.set_seed(seed, static_seed=static_seed) seed = new_seed @@ -1538,7 +1527,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase - ) -> Tuple[TensorDictBase, TensorDictBase]: + ) -> tuple[TensorDictBase, TensorDictBase]: partial_steps = tensordict.get("_step", None) tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): @@ -1615,7 +1604,7 @@ def select_and_clone(x, y): @_check_start def step_and_maybe_reset( self, tensordict: TensorDictBase - ) -> Tuple[TensorDictBase, TensorDictBase]: + ) -> tuple[TensorDictBase, TensorDictBase]: if not self._use_buffers: # Simply dispatch the input to the workers # return self._step_and_maybe_reset_no_buffers(tensordict) @@ -1825,7 +1814,7 @@ def _wait_for_workers(self, workers_range): def _step_no_buffers( self, tensordict: TensorDictBase - ) -> Tuple[TensorDictBase, TensorDictBase]: + ) -> tuple[TensorDictBase, TensorDictBase]: partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): @@ -2078,7 +2067,7 @@ def _reset_no_buffers( tensordict: TensorDictBase, reset_kwargs_list, needs_resetting, - ) -> Tuple[TensorDictBase, TensorDictBase]: + ) -> tuple[TensorDictBase, TensorDictBase]: if is_tensor_collection(tensordict): # tensordict = tensordict.consolidate(share_memory=True, num_threads=1) if self.consolidate: @@ -2271,8 +2260,8 @@ def _shutdown_workers(self) -> None: @_check_start def set_seed( - self, seed: Optional[int] = None, static_seed: bool = False - ) -> Optional[int]: + self, seed: int | None = None, static_seed: bool = False + ) -> int | None: self._seeds = [] 
for channel in self.parent_channels: channel.send(("seed", (seed, static_seed))) @@ -2349,8 +2338,8 @@ def _recursively_strip_locks_from_state_dict(state_dict: OrderedDict) -> Ordered def _run_worker_pipe_shared_mem( parent_pipe: connection.Connection, child_pipe: connection.Connection, - env_fun: Union[EnvBase, Callable], - env_fun_kwargs: Dict[str, Any], + env_fun: EnvBase | Callable, + env_fun_kwargs: dict[str, Any], mp_event: mp.Event = None, shared_tensordict: TensorDictBase = None, _selected_input_keys=None, @@ -2601,8 +2590,8 @@ def look_for_cuda(tensor, has_cuda=has_cuda): def _run_worker_pipe_direct( parent_pipe: connection.Connection, child_pipe: connection.Connection, - env_fun: Union[EnvBase, Callable], - env_fun_kwargs: Dict[str, Any], + env_fun: EnvBase | Callable, + env_fun_kwargs: dict[str, Any], mp_event: mp.Event = None, non_blocking: bool = False, has_lazy_inputs: bool = False, diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py index ad8625a0418..945a377255e 100644 --- a/torchrl/envs/custom/chess.py +++ b/torchrl/envs/custom/chess.py @@ -7,15 +7,12 @@ import importlib.util import io import pathlib -from typing import Dict, Optional import torch from tensordict import TensorDict, TensorDictBase from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded - from torchrl.envs import EnvBase from torchrl.envs.common import _EnvPostInit - from torchrl.envs.utils import _classproperty @@ -49,10 +46,8 @@ def maybe_add_keys(condition, in_key, out_key): ) elif include_hash_inv: raise ValueError( - ( - "'include_hash_inv=True' can only be set if" - f"'include_hash=True', but got 'include_hash={include_hash}'." - ) + "'include_hash_inv=True' can only be set if " + f"'include_hash=True', but got 'include_hash={include_hash}'." ) if kwargs.get("mask_actions", True): from torchrl.envs import ActionMask @@ -197,7 +192,7 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta): """ - _hash_table: Dict[int, str] = {} + _hash_table: dict[int, str] = {} _PGN_RESTART = """[Event "?"] [Site "?"] [Date "????.??.??"] @@ -231,7 +226,7 @@ def san_moves(cls): def _legal_moves_to_index( self, tensordict: TensorDictBase | None = None, - board: "chess.Board" | None = None, # noqa: F821 + board: chess.Board | None = None, # noqa: F821 return_mask: bool = False, pad: bool = False, ) -> torch.Tensor: @@ -357,16 +352,12 @@ def __init__( def _is_done(self, board): return board.is_game_over() | board.is_fifty_moves() - def all_actions( - self, tensordict: Optional[TensorDictBase] = None - ) -> TensorDictBase: + def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase: if not self.mask_actions: raise RuntimeError( - ( - "Cannot generate legal actions since 'mask_actions=False' was " - "set. If you really want to generate all actions, not just " - "legal ones, call 'env.full_action_spec.enumerate()'." - ) + "Cannot generate legal actions since 'mask_actions=False' was " + "set. If you really want to generate all actions, not just " + "legal ones, call 'env.full_action_spec.enumerate()'."
) return super().all_actions(tensordict) @@ -480,8 +471,8 @@ def _get_tensor_image(cls, board): @classmethod def _pgn_to_board( - cls, pgn_string: str, board: "chess.Board" | None = None # noqa: F821 - ) -> "chess.Board": # noqa: F821 + cls, pgn_string: str, board: chess.Board | None = None # noqa: F821 + ) -> chess.Board: # noqa: F821 pgn_io = io.StringIO(pgn_string) game = cls.lib.pgn.read_game(pgn_io) if board is None: @@ -493,7 +484,7 @@ def _pgn_to_board( return board @classmethod - def _add_move_to_pgn(cls, pgn_string: str, move: "chess.Move") -> str: # noqa: F821 + def _add_move_to_pgn(cls, pgn_string: str, move: chess.Move) -> str: # noqa: F821 pgn_io = io.StringIO(pgn_string) game = cls.lib.pgn.read_game(pgn_io) if game is None: @@ -502,7 +493,7 @@ def _add_move_to_pgn(cls, pgn_string: str, move: "chess.Move") -> str: # noqa: return str(game) @classmethod - def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821 + def _board_to_pgn(cls, board: chess.Board) -> str: # noqa: F821 game = cls.lib.pgn.Game.from_board(board) pgn_string = str(game) return pgn_string diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py index 0a464a3e390..b5ec2cd9be0 100644 --- a/torchrl/envs/custom/tictactoeenv.py +++ b/torchrl/envs/custom/tictactoeenv.py @@ -4,11 +4,8 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Optional - import torch from tensordict import TensorDict, TensorDictBase - from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.envs.common import EnvBase @@ -279,7 +276,7 @@ def full(board: torch.Tensor) -> bool: def get_action_mask(): pass - def rand_action(self, tensordict: Optional[TensorDictBase] = None): + def rand_action(self, tensordict: TensorDictBase | None = None): mask = tensordict.get("mask") action_spec = self.action_spec if tensordict.ndim: diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index f4cb8e263a1..0d81a8f7705 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -7,13 +7,11 @@ from collections import OrderedDict from multiprocessing.sharedctypes import Synchronized -from typing import Callable, Dict, Optional, Union +from typing import Callable import torch from tensordict import TensorDictBase - from torchrl._utils import logger as torchrl_logger - from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase, EnvMetaData @@ -80,7 +78,7 @@ class EnvCreator: def __init__( self, create_env_fn: Callable[..., EnvBase], - create_env_kwargs: Optional[Dict] = None, + create_env_kwargs: dict | None = None, share_memory: bool = True, **kwargs, ) -> None: @@ -230,9 +228,7 @@ def env_creator(fun: Callable) -> EnvCreator: return EnvCreator(fun) -def get_env_metadata( - env_or_creator: Union[EnvBase, Callable], kwargs: Optional[Dict] = None -): +def get_env_metadata(env_or_creator: EnvBase | Callable, kwargs: dict | None = None): """Retrieves an EnvMetaData object from an env.""" if isinstance(env_or_creator, (EnvBase,)): return EnvMetaData.metadata_from_env(env_or_creator) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 857b3b96b2f..cba2b610648 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -8,15 +8,14 @@ import abc import re import warnings -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Mapping, Sequence import numpy as np import torch from tensordict import
NonTensorData, TensorDict, TensorDictBase from torchrl._utils import logger as torchrl_logger - from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded -from torchrl.envs.common import _EnvWrapper, _maybe_unlock, EnvBase +from torchrl.envs.common import EnvBase, _EnvWrapper, _maybe_unlock class BaseInfoDictReader(metaclass=abc.ABCMeta): @abc.abstractmethod def __call__( - self, info_dict: Dict[str, Any], tensordict: TensorDictBase + self, info_dict: dict[str, Any], tensordict: TensorDictBase ) -> TensorDictBase: raise NotImplementedError @property @abc.abstractmethod - def info_spec(self) -> Dict[str, TensorSpec]: + def info_spec(self) -> dict[str, TensorSpec]: raise NotImplementedError @@ -67,8 +66,8 @@ class default_info_dict_reader(BaseInfoDictReader): def __init__( self, - keys: List[str] | None = None, - spec: Sequence[TensorSpec] | Dict[str, TensorSpec] | Composite | None = None, + keys: list[str] | None = None, + spec: Sequence[TensorSpec] | dict[str, TensorSpec] | Composite | None = None, ignore_private: bool = True, ): self.ignore_private = ignore_private @@ -98,7 +97,7 @@ def __init__( self._info_spec = _info_spec def __call__( - self, info_dict: Dict[str, Any], tensordict: TensorDictBase + self, info_dict: dict[str, Any], tensordict: TensorDictBase ) -> TensorDictBase: if not isinstance(info_dict, (dict, TensorDictBase)) and len(self.keys): warnings.warn( @@ -142,7 +141,7 @@ def reset(self): self._info_spec = None @property - def info_spec(self) -> Dict[str, TensorSpec]: + def info_spec(self) -> dict[str, TensorSpec]: return self._info_spec @@ -166,7 +165,7 @@ class GymLikeEnv(_EnvWrapper): It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. """ - _info_dict_reader: List[BaseInfoDictReader] + _info_dict_reader: list[BaseInfoDictReader] @classmethod def __new__(cls, *args, **kwargs): @@ -191,7 +190,7 @@ def read_done( terminated: bool | None = None, truncated: bool | None = None, done: bool | None = None, - ) -> Tuple[bool | np.ndarray, bool | np.ndarray, bool | np.ndarray, bool]: + ) -> tuple[bool | np.ndarray, bool | np.ndarray, bool | np.ndarray, bool]: """Done state reader. In torchrl, a `"done"` signal means that a trajectory has reached its end, @@ -257,8 +256,8 @@ def read_reward(self, reward): return reward def read_obs( - self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] - ) -> Dict[str, Any]: + self, observations: dict[str, Any] | torch.Tensor | np.ndarray + ) -> dict[str, Any]: """Reads an observation from the environment and returns an observation compatible with the output TensorDict. Args: @@ -371,7 +370,7 @@ def validated(self, value): self.__dict__["_validated"] = value def _reset( - self, tensordict: Optional[TensorDictBase] = None, **kwargs + self, tensordict: TensorDictBase | None = None, **kwargs ) -> TensorDictBase: obs, info = self._reset_output_transform(self._env.reset(**kwargs)) @@ -398,8 +397,8 @@ def _reset( @abc.abstractmethod def _output_transform( - self, step_outputs_tuple: Tuple - ) -> Tuple[ + self, step_outputs_tuple: tuple + ) -> tuple[ Any, float | np.ndarray, bool | np.ndarray | None, @@ -434,7 +433,7 @@ def _output_transform( ... @abc.abstractmethod - def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple: + def _reset_output_transform(self, reset_outputs_tuple: tuple) -> tuple: ...
@_maybe_unlock diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index ba1fdcfc9ae..a8f80aa6baf 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -5,16 +5,13 @@ from __future__ import annotations import collections - import importlib import os -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict import numpy as np import torch - -from torchrl._utils import logger as torchrl_logger, VERBOSE - +from torchrl._utils import VERBOSE, logger as torchrl_logger from torchrl.data.tensor_specs import ( Bounded, Categorical, @@ -23,7 +20,6 @@ TensorSpec, Unbounded, ) - from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict from torchrl.envs.gym_like import GymLikeEnv from torchrl.envs.utils import _classproperty @@ -41,7 +37,7 @@ def _dmcontrol_to_torchrl_spec_transform( spec, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, device: DEVICE_TYPING = None, categorical_discrete_encoding: bool = False, ) -> TensorSpec: @@ -94,7 +90,7 @@ def _dmcontrol_to_torchrl_spec_transform( raise NotImplementedError(type(spec)) -def _get_envs(to_dict: bool = True) -> Dict[str, Any]: +def _get_envs(to_dict: bool = True) -> dict[str, Any]: if not _has_dm_control: raise ImportError("Cannot find dm_control in virtual environment.") from dm_control import suite @@ -111,7 +107,7 @@ def _get_envs(to_dict: bool = True) -> Dict[str, Any]: return d.items() -def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor: +def _robust_to_tensor(array: float | np.ndarray) -> torch.Tensor: if isinstance(array, np.ndarray): return torch.as_tensor(array.copy()) else: @@ -211,11 +207,11 @@ def __init__(self, env=None, **kwargs): def _build_env( self, env, - _seed: Optional[int] = None, + _seed: int | None = None, from_pixels: bool = False, - render_kwargs: Optional[dict] = None, + render_kwargs: dict | None = None, pixels_only: bool = False, - camera_id: Union[int, str] = 0, + camera_id: int | str = 0, **kwargs, ): self.from_pixels = from_pixels @@ -235,7 +231,7 @@ def _build_env( ) return env - def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 + def _make_specs(self, env: gym.Env) -> None: # noqa: F821 # specs are defined when first called self.observation_spec = _dmcontrol_to_torchrl_spec_transform( self._env.observation_spec(), device=self.device @@ -260,7 +256,7 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 self._env.action_spec(), device=self.device ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): dm_control = self.lib from dm_control.suite.wrappers import pixels @@ -286,11 +282,11 @@ def to(self, device: DEVICE_TYPING) -> DMControlEnv: self._set_egl_device(self.device) return self - def _init_env(self, seed: Optional[int] = None) -> Optional[int]: + def _init_env(self, seed: int | None = None) -> int | None: seed = self.set_seed(seed) return seed - def _set_seed(self, _seed: Optional[int]) -> Optional[int]: + def _set_seed(self, _seed: int | None) -> int | None: from dm_control.suite.wrappers import pixels if _seed is None: @@ -308,8 +304,8 @@ def _set_seed(self, _seed: Optional[int]) -> Optional[int]: return _seed def _output_transform( - self, timestep_tuple: Tuple["TimeStep"] # noqa: F821 - ) -> Tuple[np.ndarray, float, bool, bool, dict]: + self, timestep_tuple: tuple[TimeStep] # noqa: F821 + ) -> tuple[np.ndarray, float, bool, bool, dict]: from dm_env import StepType if type(timestep_tuple) is not tuple: @@ -427,7 
+423,7 @@ def _build_env( self, env_name: str, task_name: str, - _seed: Optional[int] = None, + _seed: int | None = None, **kwargs, ): from dm_control import suite @@ -467,7 +463,7 @@ def rebuild_with_kwargs(self, **new_kwargs): self._env = self._build_env() self._make_specs(self._env) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env_name" in kwargs: env_name = kwargs["env_name"] if "task_name" in kwargs: diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index a4339820b9f..b59c7101c29 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -6,11 +6,10 @@ from __future__ import annotations import importlib -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any import numpy as np import torch - from tensordict import TensorDict, TensorDictBase from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded @@ -72,7 +71,7 @@ def lib(cls): def __init__( self, - env: Optional["envpool.python.envpool.EnvPoolMixin"] = None, # noqa: F821 + env: envpool.python.envpool.EnvPoolMixin | None = None, # noqa: F821 **kwargs, ): if not _has_envpool: @@ -88,9 +87,9 @@ def __init__( # Buffer to keep the latest observation for each worker # It's a TensorDict when the observation consists of several variables, e.g. "position" and "velocity" - self.obs: Union[torch.tensor, TensorDict] = self.observation_spec.zero() + self.obs: torch.tensor | TensorDict = self.observation_spec.zero() - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") env = kwargs["env"] @@ -99,11 +98,11 @@ def _check_kwargs(self, kwargs: Dict): if not isinstance(env, (envpool.python.envpool.EnvPoolMixin,)): raise TypeError("env is not of type 'envpool.python.envpool.EnvPoolMixin'.") - def _build_env(self, env: "envpool.python.envpool.EnvPoolMixin"): # noqa: F821 + def _build_env(self, env: envpool.python.envpool.EnvPoolMixin): # noqa: F821 return env def _make_specs( - self, env: "envpool.python.envpool.EnvPoolMixin" # noqa: F821 + self, env: envpool.python.envpool.EnvPoolMixin # noqa: F821 ) -> None: # noqa: F821 from torchrl.envs.libs.gym import set_gym_backend @@ -114,7 +113,7 @@ def _make_specs( self.reward_spec = output_spec["full_reward_spec"] self.done_spec = output_spec["full_done_spec"] - def _init_env(self) -> Optional[int]: + def _init_env(self) -> int | None: pass def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -212,10 +211,8 @@ def __repr__(self) -> str: def _transform_reset_output( self, - envpool_output: Tuple[ - Union["treevalue.TreeValue", np.ndarray], Any # noqa: F821 - ], - reset_workers: Optional[torch.Tensor], + envpool_output: tuple[treevalue.TreeValue | np.ndarray, Any], # noqa: F821 + reset_workers: torch.Tensor | None, ): """Process output of envpool env.reset.""" import treevalue @@ -243,7 +240,7 @@ def _transform_reset_output( return obs def _transform_step_output( - self, envpool_output: Tuple[Any, Any, Any, ...] + self, envpool_output: tuple[Any, Any, Any, ...] 
) -> TensorDict: """Process output of envpool env.step.""" out = envpool_output @@ -272,8 +269,8 @@ def _transform_step_output( return tensordict_out def _treevalue_or_numpy_to_tensor_or_dict( - self, x: Union["treevalue.TreeValue", np.ndarray] # noqa: F821 - ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + self, x: treevalue.TreeValue | np.ndarray # noqa: F821 + ) -> torch.Tensor | dict[str, torch.Tensor]: """Converts observation returned by EnvPool. EnvPool step and reset return observation as a numpy array or a TreeValue of numpy arrays, which we convert @@ -291,8 +288,8 @@ def _treevalue_or_numpy_to_tensor_or_dict( return ret def _treevalue_to_dict( - self, tv: "treevalue.TreeValue" # noqa: F821 - ) -> Dict[str, Any]: + self, tv: treevalue.TreeValue # noqa: F821 + ) -> dict[str, Any]: """Converts TreeValue to a dictionary. Currently only supports depth 1 trees, but can easily be extended to arbitrary depth if necessary. @@ -301,7 +298,7 @@ def _treevalue_to_dict( return {k[0]: torch.as_tensor(v) for k, v in treevalue.flatten(tv)} - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): if seed is not None: torchrl_logger.info( "MultiThreadedEnvWrapper._set_seed ignored, as setting seed in an existing environment is not\ supported. Please create a new environment." @@ -359,7 +356,7 @@ def __init__( num_workers: int, env_name: str, *, - create_env_kwargs: Optional[Dict[str, Any]] = None, + create_env_kwargs: dict[str, Any] | None = None, **kwargs, ): self.env_name = env_name.replace("ALE/", "") # Naming convention of EnvPool @@ -376,7 +373,7 @@ def _build_env( self, env_name: str, num_workers: int, - create_env_kwargs: Optional[Dict[str, Any]], + create_env_kwargs: dict[str, Any] | None, ) -> Any: import envpool @@ -390,7 +387,7 @@ def _build_env( ) return super()._build_env(env) - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): """Library EnvPool only supports setting a seed by recreating the environment.""" if seed is not None: torchrl_logger.debug("Recreating EnvPool environment to set seed.") @@ -401,7 +398,7 @@ def _set_seed(self, seed: Optional[int]): create_env_kwargs=self.create_env_kwargs, ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): for arg in ["num_workers", "env_name", "create_env_kwargs"]: if arg not in kwargs: raise TypeError(f"Expected '{arg}' to be part of kwargs") diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index d5bc3e3f4e8..47a718caf15 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -10,19 +10,16 @@ import warnings from copy import copy from types import ModuleType -from typing import Dict, List, Tuple +from typing import Dict, List from warnings import warn import numpy as np import torch from packaging import version - from tensordict import TensorDict, TensorDictBase from torch.utils._pytree import tree_map - from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( - _minmax_dtype, Binary, Bounded, Categorical, @@ -33,13 +30,12 @@ OneHot, TensorSpec, Unbounded, + _minmax_dtype, ) from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict from torchrl.envs.batched_envs import CloudpickleWrapper from torchrl.envs.common import _EnvPostInit - -from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv - +from torchrl.envs.gym_like import GymLikeEnv, default_info_dict_reader from torchrl.envs.utils import _classproperty try: @@ -697,7 +693,7 @@ def
_get_envs(to_dict=False) -> list: if not _has_gym: raise ImportError("Gym(nasium) could not be found in your virtual environment.") envs = _get_gym_envs() @@ -1052,7 +1048,7 @@ def _get_batch_size(self, env): # noqa: F811 def _get_batch_size(self, env): # noqa: F811 raise ImportError(GYMNASIUM_1_ERROR) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") env = kwargs["env"] @@ -1064,7 +1060,7 @@ def _build_env( env, from_pixels: bool = False, pixels_only: bool = False, - ) -> "gym.core.Env": # noqa: F821 + ) -> gym.core.Env: # noqa: F821 self.batch_size = self._get_batch_size(env) env_from_pixels = _is_from_pixels(env) @@ -1230,7 +1226,7 @@ def _reward_space(self, env): # noqa: F811 rs = env.reward_space return rs - def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 + def _make_specs(self, env: gym.Env, batch_size=None) -> None: # noqa: F821 # If batch_size is provided, we use it to tell what batch size must be used # instead of self.batch_size cur_batch_size = self.batch_size if batch_size is None else torch.Size([]) @@ -1647,7 +1643,7 @@ def _build_env( self, env_name: str, **kwargs, - ) -> "gym.core.Env": # noqa: F821 + ) -> gym.core.Env: # noqa: F821 if not _has_gym: raise RuntimeError( f"gym not found, unable to create {env_name}. " @@ -1716,7 +1712,7 @@ def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 def env_name(self): return self._constructor_kwargs["env_name"] - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env_name" not in kwargs: raise TypeError("Expected 'env_name' to be part of kwargs") @@ -1930,7 +1926,7 @@ def reset(self): self._final_validated = False -def _flip_info_tuple(info: Tuple[Dict]) -> Dict[str, tuple]: +def _flip_info_tuple(info: tuple[dict]) -> dict[str, tuple]: # In Gym < 0.24, batched envs returned tuples of dict, and not dict of tuples. # We patch this by flipping the tuple -> dict order. info_example = set(info[0]) diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py index 0a64c395126..4d1bbe60d43 100644 --- a/torchrl/envs/libs/isaacgym.py +++ b/torchrl/envs/libs/isaacgym.py @@ -5,14 +5,12 @@ from __future__ import annotations import importlib.util - import itertools import warnings -from typing import Any, Dict, Tuple, Union +from typing import Any import numpy as np import torch - from tensordict import TensorDictBase from torchrl.data import Composite from torchrl.envs.libs.gym import GymWrapper @@ -45,7 +43,7 @@ def lib(self): return isaacgym def __init__( - self, env: "isaacgymenvs.tasks.base.vec_task.Env", **kwargs # noqa: F821 + self, env: isaacgymenvs.tasks.base.vec_task.Env, **kwargs # noqa: F821 ): warnings.warn( "IsaacGym environment support is an experimental feature that may change in the future."
@@ -57,7 +55,7 @@ def __init__( # by convention in IsaacGymEnvs self.task = env.__name__ - def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 + def _make_specs(self, env: gym.Env) -> None: # noqa: F821 super()._make_specs(env, batch_size=self.batch_size) self.full_done_spec = Composite( { @@ -133,7 +131,7 @@ def read_done( terminated: bool = None, truncated: bool | None = None, done: bool | None = None, - ) -> Tuple[bool, bool, bool]: + ) -> tuple[bool, bool, bool]: if terminated is not None: terminated = terminated.bool() if truncated is not None: @@ -146,8 +144,8 @@ def read_reward(self, total_reward): return total_reward def read_obs( - self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] - ) -> Dict[str, Any]: + self, observations: dict[str, Any] | torch.Tensor | np.ndarray + ) -> dict[str, Any]: """Reads an observation from the environment and returns an observation compatible with the output TensorDict. Args: diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 3bbaf7caa1c..e00b869e755 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -5,13 +5,11 @@ from __future__ import annotations import importlib.util -from typing import Dict, Optional, Tuple, Union import numpy as np import torch from packaging import version from tensordict import TensorDict, TensorDictBase - from torchrl.envs.common import _EnvPostInit from torchrl.envs.utils import _classproperty @@ -51,7 +49,7 @@ def _get_envs(): def _jumanji_to_torchrl_spec_transform( spec, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, device: DEVICE_TYPING = None, categorical_action_encoding: bool = True, ) -> TensorSpec: @@ -352,7 +350,7 @@ def lib(self): def __init__( self, - env: "jumanji.env.Environment" = None, # noqa: F821 + env: jumanji.env.Environment = None, # noqa: F821 categorical_action_encoding=True, jit: bool = True, **kwargs, @@ -388,11 +386,11 @@ def jit(self, value): def _build_env( self, env, - _seed: Optional[int] = None, + _seed: int | None = None, from_pixels: bool = False, - render_kwargs: Optional[dict] = None, + render_kwargs: dict | None = None, pixels_only: bool = False, - camera_id: Union[int, str] = 0, + camera_id: int | str = 0, **kwargs, ): self.from_pixels = from_pixels @@ -480,7 +478,7 @@ def _make_reward_spec(self, env) -> TensorSpec: reward_spec.shape = torch.Size([1]) return reward_spec.expand([*self.batch_size, *reward_spec.shape]) - def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 + def _make_specs(self, env: jumanji.env.Environment) -> None: # noqa: F821 # extract spec from jumanji definition self.action_spec = self._make_action_spec(env) @@ -495,7 +493,7 @@ def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 # build state example for data conversion self._state_example = self._make_state_example(env) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): jumanji = self.lib if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") @@ -674,7 +672,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict_out def _reset( - self, tensordict: Optional[TensorDictBase] = None, **kwargs + self, tensordict: TensorDictBase | None = None, **kwargs ) -> TensorDictBase: import jax from jax import numpy as jnp @@ -736,10 +734,10 @@ def read_reward(self, reward): return reward - def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: + def _output_transform(self, 
step_outputs_tuple: tuple) -> tuple: ... - def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple: + def _reset_output_transform(self, reset_outputs_tuple: tuple) -> tuple: ... @@ -938,7 +936,7 @@ def _build_env( self, env_name: str, **kwargs, - ) -> "jumanji.env.Environment": # noqa: F821 + ) -> jumanji.env.Environment: # noqa: F821 if not _has_jumanji: raise ImportError( f"jumanji not found, unable to create {env_name}. " @@ -957,7 +955,7 @@ def _build_env( def env_name(self): return self._constructor_kwargs["env_name"] - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env_name" not in kwargs: raise TypeError("Expected 'env_name' to be part of kwargs") diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index ce6fd0bd179..c8ff477fa06 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -5,17 +5,14 @@ from __future__ import annotations import importlib - from typing import Dict, List, Mapping, Sequence import torch - from tensordict import TensorDict, TensorDictBase - from torchrl.data import Categorical, Composite, TensorSpec from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform -from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType +from torchrl.envs.utils import MarlGroupMapType, _classproperty, check_marl_grouping _has_meltingpot = importlib.util.find_spec("meltingpot") is not None @@ -31,7 +28,7 @@ def _get_envs(): return list(substrate_configs.SUBSTRATES) -def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict: # noqa +def _filter_global_state_from_dict(obs_dict: dict, world: bool) -> dict: # noqa return { key: value for key, value in obs_dict.items() @@ -40,8 +37,8 @@ def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict: # noqa def _remove_world_observations_from_obs_spec( - observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]], # noqa -) -> Sequence[Mapping[str, "dm_env.specs.Array"]]: # noqa + observation_spec: Sequence[Mapping[str, dm_env.specs.Array]], # noqa +) -> Sequence[Mapping[str, dm_env.specs.Array]]: # noqa return [ _filter_global_state_from_dict(agent_obs, world=False) for agent_obs in observation_spec @@ -49,8 +46,8 @@ def _remove_world_observations_from_obs_spec( def _global_state_spec_from_obs_spec( - observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]] # noqa -) -> Mapping[str, "dm_env.specs.Array"]: # noqa + observation_spec: Sequence[Mapping[str, dm_env.specs.Array]] # noqa +) -> Mapping[str, dm_env.specs.Array]: # noqa # We only look at agent 0 since world entries are the same for all agents world_entries = _filter_global_state_from_dict(observation_spec[0], world=True) if len(world_entries) != 1 and _WORLD_PREFIX + "RGB" not in world_entries: @@ -60,7 +57,7 @@ def _global_state_spec_from_obs_spec( return _remove_world_prefix(world_entries) -def _remove_world_prefix(world_entries: Dict) -> Dict: +def _remove_world_prefix(world_entries: dict) -> dict: return {key[len(_WORLD_PREFIX) :]: value for key, value in world_entries.items()} @@ -181,10 +178,10 @@ def available_envs(cls): def __init__( self, - env: "meltingpot.utils.substrates.substrate.Substrate" = None, # noqa + env: meltingpot.utils.substrates.substrate.Substrate = None, # noqa categorical_actions: bool = True, group_map: MarlGroupMapType - | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + | dict[str, list[str]] = 
MarlGroupMapType.ALL_IN_ONE_GROUP, max_steps: int = None, **kwargs, ): @@ -198,7 +195,7 @@ def __init__( def _build_env( self, - env: "meltingpot.utils.substrates.substrate.Substrate", # noqa + env: meltingpot.utils.substrates.substrate.Substrate, # noqa ): return env @@ -208,7 +205,7 @@ def _make_group_map(self): check_marl_grouping(self.group_map, self.agent_names) def _make_specs( - self, env: "meltingpot.utils.substrates.substrate.Substrate" # noqa + self, env: meltingpot.utils.substrates.substrate.Substrate # noqa ) -> None: mp_obs_spec = self._env.observation_spec() # List of dict of arrays mp_obs_spec_no_world = _remove_world_observations_from_obs_spec( @@ -278,9 +275,9 @@ def _make_specs( def _make_group_specs( self, group: str, - torchrl_agent_obs_specs: List[TensorSpec], - torchrl_agent_act_specs: List[TensorSpec], - torchrl_rew_spec: List[TensorSpec], + torchrl_agent_obs_specs: list[TensorSpec], + torchrl_agent_act_specs: list[TensorSpec], + torchrl_rew_spec: list[TensorSpec], ): # Agent specs action_specs = [] @@ -327,7 +324,7 @@ def _make_group_specs( group_reward_spec, ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): meltingpot = self.lib if "env" not in kwargs: @@ -558,12 +555,12 @@ class MeltingpotEnv(MeltingpotWrapper): def __init__( self, - substrate: str | "ml_collections.config_dict.ConfigDict", # noqa + substrate: str | ml_collections.config_dict.ConfigDict, # noqa *, max_steps: int | None = None, categorical_actions: bool = True, group_map: MarlGroupMapType - | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, **kwargs, ): if not _has_meltingpot: @@ -579,14 +576,14 @@ def __init__( **kwargs, ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "substrate" not in kwargs: raise TypeError("Could not find environment key 'substrate' in kwargs.") def _build_env( self, - substrate: str | "ml_collections.config_dict.ConfigDict", # noqa - ) -> "meltingpot.utils.substrates.substrate.Substrate": # noqa + substrate: str | ml_collections.config_dict.ConfigDict, # noqa + ) -> meltingpot.utils.substrates.substrate.Substrate: # noqa from meltingpot import substrate as mp_substrate if isinstance(substrate, str): diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py index 3a2ab55cd13..5c248e476a9 100644 --- a/torchrl/envs/libs/openspiel.py +++ b/torchrl/envs/libs/openspiel.py @@ -6,11 +6,9 @@ from __future__ import annotations import importlib.util -from typing import Dict, List import torch from tensordict import TensorDict, TensorDictBase - from torchrl.data.tensor_specs import ( Categorical, Composite, @@ -19,7 +17,7 @@ Unbounded, ) from torchrl.envs.common import _EnvWrapper -from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType +from torchrl.envs.utils import MarlGroupMapType, _classproperty, check_marl_grouping _has_pyspiel = importlib.util.find_spec("pyspiel") is not None @@ -159,7 +157,7 @@ def __init__( env=None, *, group_map: MarlGroupMapType - | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, categorical_actions: bool = False, return_state: bool = False, **kwargs, @@ -176,7 +174,7 @@ def __init__( # `reset` allows resetting to any state, including a terminal state self._allow_done_after_reset = True - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): pyspiel = self.lib if "env" not in 
kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") @@ -283,7 +281,7 @@ def _make_group_specs( group_reward_spec, ) - def _make_specs(self, env: "pyspiel.State") -> None: # noqa: F821 + def _make_specs(self, env: pyspiel.State) -> None: # noqa: F821 self.agent_names = [f"player_{index}" for index in range(env.num_players())] self.agent_names_to_indices_map = { agent_name: i for i, agent_name in enumerate(self.agent_names) @@ -604,7 +602,7 @@ def __init__( game_string, *, group_map: MarlGroupMapType - | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, categorical_actions=False, return_state: bool = False, **kwargs, @@ -621,7 +619,7 @@ def _build_env( self, game_string: str, **kwargs, - ) -> "pyspiel.State": # noqa: F821 + ) -> pyspiel.State: # noqa: F821 if not _has_pyspiel: raise ImportError( f"open_spiel not found, unable to create {game_string}. Consider " @@ -647,7 +645,7 @@ def _build_env( def game_string(self): return self._constructor_kwargs["game_string"] - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "game_string" not in kwargs: raise TypeError("Expected 'game_string' to be part of kwargs") diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 616362df804..3b35b1e1f32 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -7,17 +7,16 @@ import copy import importlib import warnings -from typing import Dict, List, Tuple, Union +from typing import Dict import numpy as np import packaging import torch from tensordict import TensorDictBase - from torchrl.data.tensor_specs import Categorical, Composite, OneHot, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend -from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType +from torchrl.envs.utils import MarlGroupMapType, _classproperty, check_marl_grouping _has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None @@ -36,7 +35,7 @@ def _get_envs(): return list(all_environments.keys()) -def _load_available_envs() -> Dict: +def _load_available_envs() -> dict: all_environments = {} try: from pettingzoo.mpe.all_modules import mpe_environments @@ -73,9 +72,7 @@ def _load_available_envs() -> Dict: return all_environments -def _extract_nested_with_index( - data: Union[np.ndarray, Dict[str, np.ndarray]], index: int -): +def _extract_nested_with_index(data: np.ndarray | dict[str, np.ndarray], index: int): if isinstance(data, np.ndarray): return data[index] elif isinstance(data, dict): @@ -208,12 +205,12 @@ def available_envs(cls): def __init__( self, - env: Union[ - "pettingzoo.utils.env.ParallelEnv", # noqa: F821 - "pettingzoo.utils.env.AECEnv", # noqa: F821 - ] = None, + env: ( + pettingzoo.utils.env.ParallelEnv + | pettingzoo.utils.env.AECEnv # noqa: F821 # noqa: F821 + ) = None, return_state: bool = False, - group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + group_map: MarlGroupMapType | dict[str, list[str]] | None = None, use_mask: bool = False, categorical_actions: bool = True, seed: int | None = None, @@ -232,7 +229,7 @@ def __init__( super().__init__(**kwargs, allow_done_after_reset=True) - def _get_default_group_map(self, agent_names: List[str]): + def _get_default_group_map(self, agent_names: list[str]): # This function performs the default grouping in pettingzoo if not self.parallel: # In AEC envs we will have 
one group per agent by default @@ -273,10 +270,10 @@ def lib(self): def _build_env( self, - env: Union[ - "pettingzoo.utils.env.ParallelEnv", # noqa: F821 - "pettingzoo.utils.env.AECEnv", # noqa: F821 - ], + env: ( + pettingzoo.utils.env.ParallelEnv + | pettingzoo.utils.env.AECEnv # noqa: F821 # noqa: F821 + ), ): import pettingzoo @@ -300,10 +297,10 @@ def _build_env( @set_gym_backend("gymnasium") def _make_specs( self, - env: Union[ - "pettingzoo.utils.env.ParallelEnv", # noqa: F821 - "pettingzoo.utils.env.AECEnv", # noqa: F821 - ], + env: ( + pettingzoo.utils.env.ParallelEnv + | pettingzoo.utils.env.AECEnv # noqa: F821 # noqa: F821 + ), ) -> None: # Set default for done on any or all if self.done_on_any is None: @@ -359,7 +356,7 @@ def _make_specs( self.reward_spec = reward_spec self.done_spec = done_spec - def _make_group_specs(self, group_name: str, agent_names: List[str]): + def _make_group_specs(self, group_name: str, agent_names: list[str]): n_agents = len(agent_names) action_specs = [] observation_specs = [] @@ -457,7 +454,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): group_done_spec, ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): import pettingzoo if "env" not in kwargs: @@ -588,7 +585,7 @@ def _reset( return tensordict_out - def _reset_aec(self, **kwargs) -> Tuple[Dict, Dict]: + def _reset_aec(self, **kwargs) -> tuple[dict, dict]: self._env.reset(**kwargs) observation_dict = { @@ -597,7 +594,7 @@ def _reset_aec(self, **kwargs) -> Tuple[Dict, Dict]: info_dict = self._env.infos return observation_dict, info_dict - def _reset_parallel(self, **kwargs) -> Tuple[Dict, Dict]: + def _reset_parallel(self, **kwargs) -> tuple[dict, dict]: return self._env.reset(**kwargs) def _step( @@ -741,7 +738,7 @@ def _aggregate_done(self, tensordict_out, use_any): def _step_parallel( self, tensordict: TensorDictBase, - ) -> Tuple[Dict, Dict, Dict, Dict, Dict]: + ) -> tuple[dict, dict, dict, dict, dict]: action_dict = {} for group, agents in self.group_map.items(): group_action = tensordict.get((group, "action")) @@ -758,7 +755,7 @@ def _step_parallel( def _step_aec( self, tensordict: TensorDictBase, - ) -> Tuple[Dict, Dict, Dict, Dict, Dict]: + ) -> tuple[dict, dict, dict, dict, dict]: for group, agents in self.group_map.items(): if self.agent_selection in agents: agent_index = agents.index(self._env.agent_selection) @@ -966,7 +963,7 @@ def __init__( task: str, parallel: bool, return_state: bool = False, - group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + group_map: MarlGroupMapType | dict[str, list[str]] | None = None, use_mask: bool = False, categorical_actions: bool = True, seed: int | None = None, @@ -989,7 +986,7 @@ def __init__( super().__init__(**kwargs) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "task" not in kwargs: raise TypeError("Could not find environment key 'task' in kwargs.") if "parallel" not in kwargs: @@ -1000,10 +997,10 @@ def _build_env( task: str, parallel: bool, **kwargs, - ) -> Union[ - "pettingzoo.utils.env.ParallelEnv", # noqa: F821 - "pettingzoo.utils.env.AECEnv", # noqa: F821 - ]: + ) -> ( + pettingzoo.utils.env.ParallelEnv + | pettingzoo.utils.env.AECEnv # noqa: F821 # noqa: F821 + ): self.task_name = task try: diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 2a4e04f7d71..f471dee382c 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -5,7 +5,6 @@ import importlib import os import 
warnings - from copy import copy from pathlib import Path @@ -14,12 +13,12 @@ from tensordict import TensorDict from torchrl.data.tensor_specs import Unbounded from torchrl.envs.common import _maybe_unlock -from torchrl.envs.libs.gym import ( - _gym_to_torchrl_spec_transform, - _GymAsyncMeta, - gym_backend, - GymEnv, -) +from torchrl.envs.libs.gym import ( + GymEnv, + _GymAsyncMeta, + _gym_to_torchrl_spec_transform, + gym_backend, +) from torchrl.envs.utils import _classproperty, make_composite_from_td _has_gym = ( @@ -32,7 +26,7 @@ os.environ.setdefault("sim_backend", "MUJOCO") -class set_directory(object): +class set_directory: """Sets the cwd within the context. Args: diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py index 397abbcc3c0..5f5fc6c5654 100644 --- a/torchrl/envs/libs/unity_mlagents.py +++ b/torchrl/envs/libs/unity_mlagents.py @@ -6,11 +6,9 @@ from __future__ import annotations import importlib.util -from typing import Dict, List, Optional import torch from tensordict import TensorDict, TensorDictBase - from torchrl.data.tensor_specs import ( BoundedContinuous, Categorical, @@ -20,7 +18,7 @@ Unbounded, ) from torchrl.envs.common import _EnvWrapper -from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType +from torchrl.envs.utils import MarlGroupMapType, _classproperty, check_marl_grouping _has_unity_mlagents = importlib.util.find_spec("mlagents_envs") is not None @@ -87,7 +85,6 @@ def lib(cls): if cls._lib is not None: return cls._lib - import mlagents_envs import mlagents_envs.environment cls._lib = mlagents_envs @@ -97,7 +94,7 @@ def __init__( self, env=None, *, - group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + group_map: MarlGroupMapType | dict[str, list[str]] | None = None, categorical_actions: bool = False, **kwargs, ): @@ -108,7 +105,7 @@ def __init__( self.categorical_actions = categorical_actions super().__init__(**kwargs) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): mlagents_envs = self.lib if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") @@ -177,7 +174,7 @@ def _make_group_map(self, group_map, agent_name_to_group_id_map): return group_map, agent_name_to_group_name_map def _make_specs( - self, env: "mlagents_envs.environment.UnityEnvironment" # noqa: F821 + self, env: mlagents_envs.environment.UnityEnvironment # noqa: F821 ) -> None: # NOTE: We need to reset here because mlagents only initializes the # agents and behaviors after reset. In order to build specs, we make the @@ -288,17 +285,13 @@ def _set_seed(self, seed): def _check_agent_exists(self, agent_name, group_id): if agent_name not in self.agent_name_to_group_id_map: raise RuntimeError( - ( - "Unity environment added a new agent. This is not yet " - "supported in torchrl." - ) + "Unity environment added a new agent. This is not yet " + "supported in torchrl." ) if self.agent_name_to_group_id_map[agent_name] != group_id: raise RuntimeError( - ( - "Unity environment changed the group of an agent. This " - "is not yet supported in torchrl." - ) + "Unity environment changed the group of an agent. This " + "is not yet supported in torchrl."
) ) def _update_action_mask(self): @@ -836,10 +829,10 @@ class UnityMLAgentsEnv(UnityMLAgentsWrapper): def __init__( self, - file_name: Optional[str] = None, - registered_name: Optional[str] = None, + file_name: str | None = None, + registered_name: str | None = None, *, - group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + group_map: MarlGroupMapType | dict[str, list[str]] | None = None, categorical_actions=False, **kwargs, ): @@ -853,10 +846,10 @@ def __init__( def _build_env( self, - file_name: Optional[str], - registered_name: Optional[str], + file_name: str | None, + registered_name: str | None, **kwargs, - ) -> "mlagents_envs.environment.UnityEnvironment": # noqa: F821 + ) -> mlagents_envs.environment.UnityEnvironment: # noqa: F821 if not _has_unity_mlagents: raise ImportError( "mlagents_envs not found, unable to create environment. " @@ -888,7 +881,7 @@ def file_name(self): def registered_name(self): return self._constructor_kwargs["registered_name"] - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): pass def __repr__(self) -> str: diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 1d786358ca9..6e1bb1237e6 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -6,12 +6,10 @@ import importlib.util import warnings - -from typing import Dict, List, Optional, Union +from typing import Dict, List import torch from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase - from torchrl.data.tensor_specs import ( Bounded, Categorical, @@ -25,14 +23,14 @@ Unbounded, ) from torchrl.data.utils import numpy_to_torch_dtype_dict -from torchrl.envs.common import _EnvWrapper, EnvBase +from torchrl.envs.common import EnvBase, _EnvWrapper from torchrl.envs.libs.gym import gym_backend, set_gym_backend -from torchrl.envs.utils import ( - _classproperty, - _selective_unsqueeze, - check_marl_grouping, - MarlGroupMapType, -) +from torchrl.envs.utils import ( + MarlGroupMapType, + _classproperty, + _selective_unsqueeze, + check_marl_grouping, +) _has_vmas = importlib.util.find_spec("vmas") is not None @@ -237,9 +230,9 @@ def available_envs(cls): def __init__( self, - env: "vmas.simulator.environment.environment.Environment" = None, # noqa + env: vmas.simulator.environment.environment.Environment = None, # noqa categorical_actions: bool = True, - group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + group_map: MarlGroupMapType | dict[str, list[str]] | None = None, **kwargs, ): if env is not None: @@ -253,7 +246,7 @@ def __init__( def _build_env( self, - env: "vmas.simulator.environment.environment.Environment", # noqa + env: vmas.simulator.environment.environment.Environment, # noqa from_pixels: bool = False, pixels_only: bool = False, ): @@ -281,7 +274,7 @@ def _build_env( return env - def _get_default_group_map(self, agent_names: List[str]): + def _get_default_group_map(self, agent_names: list[str]): # This function performs the default grouping in vmas. # Agents with names "_" will be grouped in group name "".
# If any of the agents does not follow the naming convention, we fall back @@ -316,7 +309,7 @@ def _get_default_group_map(self, agent_names: List[str]): return group_map def _make_specs( - self, env: "vmas.simulator.environment.environment.Environment" # noqa + self, env: vmas.simulator.environment.environment.Environment # noqa ) -> None: # Create and check group map self.agent_names = [agent.name for agent in self.agents] @@ -478,7 +471,7 @@ def _make_unbatched_group_specs(self, group: str): group_info_spec, ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): vmas = self.lib if "env" not in kwargs: @@ -489,14 +482,14 @@ def _check_kwargs(self, kwargs: Dict): "env is not of type 'vmas.simulator.environment.Environment'." ) - def _init_env(self) -> Optional[int]: + def _init_env(self) -> int | None: pass - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): self._env.seed(seed) def _reset( - self, tensordict: Optional[TensorDictBase] = None, **kwargs + self, tensordict: TensorDictBase | None = None, **kwargs ) -> TensorDictBase: if tensordict is not None and "_reset" in tensordict.keys(): _reset = tensordict.get("_reset") @@ -607,9 +600,7 @@ def _step( ) return tensordict_out - def read_obs( - self, observations: Union[Dict, torch.Tensor] - ) -> Union[Dict, torch.Tensor]: + def read_obs(self, observations: dict | torch.Tensor) -> dict | torch.Tensor: if isinstance(observations, torch.Tensor): return _selective_unsqueeze(observations, batch_size=self.batch_size) return TensorDict( @@ -617,7 +608,7 @@ def read_obs( batch_size=self.batch_size, ) - def read_info(self, infos: Dict[str, torch.Tensor]) -> torch.Tensor: + def read_info(self, infos: dict[str, torch.Tensor]) -> torch.Tensor: if len(infos) == 0: return None infos = TensorDict( @@ -777,14 +768,14 @@ class VmasEnv(VmasWrapper): def __init__( self, - scenario: Union[str, "vmas.simulator.scenario.BaseScenario"], # noqa + scenario: str | vmas.simulator.scenario.BaseScenario, # noqa *, num_envs: int, continuous_actions: bool = True, - max_steps: Optional[int] = None, + max_steps: int | None = None, categorical_actions: bool = True, - seed: Optional[int] = None, - group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + seed: int | None = None, + group_map: MarlGroupMapType | dict[str, list[str]] | None = None, **kwargs, ): if not _has_vmas: @@ -803,7 +794,7 @@ def __init__( **kwargs, ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "scenario" not in kwargs: raise TypeError("Could not find environment key 'scenario' in kwargs.") if "num_envs" not in kwargs: @@ -811,13 +802,13 @@ def _check_kwargs(self, kwargs: Dict): def _build_env( self, - scenario: Union[str, "vmas.simulator.scenario.BaseScenario"], # noqa + scenario: str | vmas.simulator.scenario.BaseScenario, # noqa num_envs: int, continuous_actions: bool, - max_steps: Optional[int], - seed: Optional[int], + max_steps: int | None, + seed: int | None, **scenario_kwargs, - ) -> "vmas.simulator.environment.environment.Environment": # noqa + ) -> vmas.simulator.environment.environment.Environment: # noqa vmas = self.lib self.scenario_name = scenario diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index d2e7a6271e5..b2b64837808 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -10,7 +10,6 @@ import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule - from 
torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase @@ -119,7 +118,7 @@ def __init__( batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, ): - super(ModelBasedEnvBase, self).__init__( + super().__init__( device=device, batch_size=batch_size, run_type_checks=run_type_checks, diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index b69f206bb01..67e863c4799 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -8,7 +8,6 @@ import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule - from torchrl.data.tensor_specs import Composite from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase @@ -28,9 +27,7 @@ def __init__( device: DEVICE_TYPING = "cpu", batch_size: Optional[torch.Size] = None, ): - super(DreamerEnv, self).__init__( - world_model, device=device, batch_size=batch_size - ) + super().__init__(world_model, device=device, batch_size=batch_size) self.obs_decoder = obs_decoder self.prior_shape = prior_shape self.belief_shape = belief_shape diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py index 6ef23b11fd5..80bd1e089a0 100644 --- a/torchrl/envs/transforms/functional.py +++ b/torchrl/envs/transforms/functional.py @@ -15,7 +15,7 @@ def _get_image_num_channels(img: Tensor) -> int: elif img.ndim > 2: return img.shape[-3] - raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim)) + raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}") def _assert_channels(img: Tensor, permitted: List[int]) -> None: diff --git a/torchrl/envs/transforms/rb_transforms.py b/torchrl/envs/transforms/rb_transforms.py index 76a8e6039f8..8507ce6d8f3 100644 --- a/torchrl/envs/transforms/rb_transforms.py +++ b/torchrl/envs/transforms/rb_transforms.py @@ -4,10 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations -from typing import List - import torch - from tensordict import NestedKey, TensorDictBase from torchrl.data.postprocs.postprocs import _multi_step_func from torchrl.envs.transforms.transforms import Transform @@ -112,9 +109,9 @@ def __init__( n_steps, gamma, *, - reward_keys: List[NestedKey] | None = None, + reward_keys: list[NestedKey] | None = None, done_key: NestedKey | None = None, - done_keys: List[NestedKey] | None = None, + done_keys: list[NestedKey] | None = None, mask_key: NestedKey | None = None, ): super().__init__() diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 7d64ada37c4..0b9afecff02 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -6,9 +6,8 @@ from typing import List, Optional, Union import torch -from tensordict import set_lazy_legacy, TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase, set_lazy_legacy from torch.hub import load_state_dict_from_url - from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms.transforms import ( @@ -364,7 +363,7 @@ def _embed_goal(self, tensordict): tensordict_in = tensordict.select("goal_image").rename_key_( "goal_image", self.in_keys[0] ) - tensordict_in = super(VIPRewardTransform, self).forward(tensordict_in) + tensordict_in = super().forward(tensordict_in) tensordict = tensordict.update( tensordict_in.rename_key_(self.out_keys[0], "goal_embedding") ) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 3dbb42e9cef..460a712d3a2 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -6,23 +6,20 @@ import weakref from numbers import Number -from typing import Dict, Optional, Sequence, Union +from typing import Sequence import numpy as np import torch from packaging import version from torch import distributions as D, nn - from torch.distributions import constraints from torch.distributions.transforms import _InverseTransform - from torchrl.modules.distributions.truncated_normal import ( TruncatedNormal as _TruncatedNormal, ) - from torchrl.modules.distributions.utils import ( - _cast_device, FasterTransformedDistribution, + _cast_device, safeatanh_noeps, safetanh_noeps, ) @@ -179,9 +176,9 @@ def __init__( self, loc: torch.Tensor, scale: torch.Tensor, - upscale: Union[torch.Tensor, float] = 5.0, - low: Union[torch.Tensor, float] = -1.0, - high: Union[torch.Tensor, float] = 1.0, + upscale: torch.Tensor | float = 5.0, + low: torch.Tensor | float = -1.0, + high: torch.Tensor | float = 1.0, tanh_loc: bool = False, ): @@ -345,9 +342,9 @@ def __init__( self, loc: torch.Tensor, scale: torch.Tensor, - upscale: Union[torch.Tensor, Number] = 5.0, - low: Union[torch.Tensor, Number] = -1.0, - high: Union[torch.Tensor, Number] = 1.0, + upscale: torch.Tensor | Number = 5.0, + low: torch.Tensor | Number = -1.0, + high: torch.Tensor | Number = 1.0, event_dims: int | None = None, tanh_loc: bool = False, safe_tanh: bool = True, @@ -543,15 +540,15 @@ class Delta(D.Distribution): """ - arg_constraints: Dict = {} + arg_constraints: dict = {} def __init__( self, param: torch.Tensor, atol: float = 1e-6, rtol: float = 1e-6, - batch_shape: Union[torch.Size, Sequence[int]] = None, - event_shape: Union[torch.Size, Sequence[int]] = None, + batch_shape: torch.Size | Sequence[int] = None, + event_shape: torch.Size | Sequence[int] = None, ): if 
batch_shape is None: batch_shape = torch.Size([]) @@ -640,8 +637,8 @@ class TanhDelta(FasterTransformedDistribution): def __init__( self, param: torch.Tensor, - low: Union[torch.Tensor, float] = -1.0, - high: Union[torch.Tensor, float] = 1.0, + low: torch.Tensor | float = -1.0, + high: torch.Tensor | float = 1.0, event_dims: int = 1, atol: float = 1e-6, rtol: float = 1e-6, @@ -714,7 +711,7 @@ def max(self): self._warn_minmax() return self.high - def update(self, net_output: torch.Tensor) -> Optional[torch.Tensor]: + def update(self, net_output: torch.Tensor) -> torch.Tensor | None: loc = net_output if self.non_trivial: device = loc.device diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index 1350aeb2bc3..f2ccf04c0fe 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -10,7 +10,7 @@ from numbers import Number import torch -from torch.distributions import constraints, Distribution +from torch.distributions import Distribution, constraints from torch.distributions.utils import broadcast_all CONST_SQRT_2 = math.sqrt(2) @@ -41,9 +41,7 @@ def __init__(self, a, b, validate_args=None, device=None): batch_shape = torch.Size() else: batch_shape = self.a.size() - super(TruncatedStandardNormal, self).__init__( - batch_shape, validate_args=validate_args - ) + super().__init__(batch_shape, validate_args=validate_args) if self.a.dtype != self.b.dtype: raise ValueError("Truncation bounds types are different") if any( @@ -154,7 +152,7 @@ def __init__(self, loc, scale, a, b, validate_args=None, device=None): self._non_std_b = b a = (a - self.loc) / self.scale b = (b - self.loc) / self.scale - super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args) + super().__init__(a, b, validate_args=validate_args) self._log_scale = self.scale.log() self._mean = self._mean * self.scale + self.loc self._variance = self._variance * self.scale**2 @@ -167,7 +165,7 @@ def _from_std_rv(self, value): return value * self.scale + self.loc def cdf(self, value): - return super(TruncatedNormal, self).cdf(self._to_std_rv(value)) + return super().cdf(self._to_std_rv(value)) def icdf(self, value): sample = self._from_std_rv(super().icdf(value)) @@ -184,4 +182,4 @@ def icdf(self, value): def log_prob(self, value): value = self._to_std_rv(value) - return super(TruncatedNormal, self).log_prob(value) - self._log_scale + return super().log_prob(value) - self._log_scale diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 8c332c4efed..fb64aa851d3 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -55,7 +55,7 @@ def __init__(self, base_distribution, transforms, validate_args=None): raise ValueError("Make a ComposeTransform first.") else: raise ValueError( - "transforms must be a Transform or list, but was {}".format(transforms) + f"transforms must be a Transform or list, but was {transforms}" ) transform = self.transforms[0] # Reshape base_distribution according to transforms. 
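The hunks above and below apply the same pyupgrade-style modernizations as the rest of the patch: zero-argument super(), f-strings in place of str.format(), builtin generics (dict, list, tuple) in place of typing.Dict/List/Tuple, and PEP 604 unions in place of Optional/Union. As a reference for reviewers, a minimal before/after sketch on a hypothetical module (illustration only, not torchrl code):

# Hypothetical example of the patterns rewritten throughout this patch.
from __future__ import annotations  # lets `float | None` parse before Python 3.10


class Base:
    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale


class Child(Base):
    # before: loc: Optional[float] = None, dims: Tuple[int, ...] = ()
    def __init__(self, loc: float | None = None, dims: tuple[int, ...] = ()) -> None:
        super().__init__()  # before: super(Child, self).__init__()
        self.loc = 0.0 if loc is None else loc
        self.dims = dims

    def describe(self) -> str:
        # before: "loc={} dims={}".format(self.loc, self.dims)
        return f"loc={self.loc} dims={self.dims}"


print(Child(0.5, (2, 3)).describe())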
diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index d0c40b4fbc4..923bfb9662d 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -5,7 +5,6 @@ from __future__ import annotations import dataclasses - import importlib from contextlib import nullcontext from dataclasses import dataclass @@ -113,7 +112,7 @@ def __init__( f"Config of type {type(config)} is not supported." ) from err - super(DecisionTransformer, self).__init__() + super().__init__() with torch.device(device) if device is not None else nullcontext(): gpt_config = transformers.GPT2Config( diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 571ace39bc7..d5ac3157b8d 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -7,10 +7,9 @@ import functools import math import warnings -from typing import List, Optional, Sequence, Union +from typing import Sequence import torch - from tensordict.nn import TensorDictModuleBase from tensordict.utils import NestedKey from torch import distributions as d, nn @@ -21,7 +20,7 @@ from torchrl._utils import prod from torchrl.data.tensor_specs import Unbounded from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS -from torchrl.envs.utils import exploration_type, ExplorationType +from torchrl.envs.utils import ExplorationType, exploration_type from torchrl.modules.distributions.utils import _cast_transform_device from torchrl.modules.utils import inv_softplus @@ -56,8 +55,8 @@ def __init__( in_features: int, out_features: int, bias: bool = True, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[torch.dtype] = None, + device: DEVICE_TYPING | None = None, + dtype: torch.dtype | None = None, std_init: float = 0.1, ): nn.Module.__init__(self) @@ -128,7 +127,7 @@ def reset_noise(self) -> None: if self.bias_mu is not None: self.bias_epsilon.copy_(epsilon_out) - def _scale_noise(self, size: Union[int, torch.Size, Sequence]) -> torch.Tensor: + def _scale_noise(self, size: int | torch.Size | Sequence) -> torch.Tensor: if isinstance(size, int): size = (size,) x = torch.randn(*size, device=self.weight_mu.device) @@ -142,7 +141,7 @@ def weight(self) -> torch.Tensor: return self.weight_mu @property - def bias(self) -> Optional[torch.Tensor]: + def bias(self) -> torch.Tensor | None: if self.bias_mu is not None: if self.training: return self.bias_mu + self.bias_sigma * self.bias_epsilon @@ -177,8 +176,8 @@ def __init__( self, out_features: int, bias: bool = True, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[torch.dtype] = None, + device: DEVICE_TYPING | None = None, + dtype: torch.dtype | None = None, std_init: float = 0.1, ): super().__init__(0, 0, False, device=device) @@ -323,8 +322,8 @@ def __init__( scale_min: float = 0.01, scale_max: float = 10.0, learn_sigma: bool = True, - transform: Optional[d.Transform] = None, - device: Optional[DEVICE_TYPING] = None, + transform: d.Transform | None = None, + device: DEVICE_TYPING | None = None, ) -> None: super().__init__() self.action_dim = action_dim @@ -416,7 +415,7 @@ def forward(self, mu, state, _eps_gSDE): action = self.transform(action) return mu, sigma, action, _eps_gSDE - def to(self, device_or_dtype: Union[torch.dtype, DEVICE_TYPING]): + def to(self, device_or_dtype: torch.dtype | DEVICE_TYPING): if isinstance(device_or_dtype, DEVICE_TYPING_ARGS): self.transform = _cast_transform_device(self.transform, 
device_or_dtype) return super().to(device_or_dtype) @@ -458,8 +457,8 @@ def __init__( scale_min: float = 0.01, scale_max: float = 10.0, learn_sigma: bool = True, - transform: Optional[d.Transform] = None, - device: Optional[DEVICE_TYPING] = None, + transform: d.Transform | None = None, + device: DEVICE_TYPING | None = None, ) -> None: super().__init__( 0, @@ -642,8 +641,8 @@ class ConsistentDropoutModule(TensorDictModuleBase): def __init__( self, p: float, - in_keys: NestedKey | List[NestedKey], - out_keys: NestedKey | List[NestedKey] | None = None, + in_keys: NestedKey | list[NestedKey], + out_keys: NestedKey | list[NestedKey] | None = None, input_shape: torch.Size = None, input_dtype: torch.dtype | None = None, ): diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 36a11a508f5..711e7cda87f 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -5,24 +5,22 @@ from __future__ import annotations import dataclasses - from copy import deepcopy from numbers import Number -from typing import Callable, Dict, List, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Sequence, Type import torch from torch import nn - from torchrl._utils import prod from torchrl.data.utils import DEVICE_TYPING from torchrl.modules.models.decision_transformer import DecisionTransformer from torchrl.modules.models.utils import ( - _find_depth, - create_on_device, LazyMapping, SquashDims, Squeeze2dLayer, SqueezeLayer, + _find_depth, + create_on_device, ) from torchrl.modules.tensordict_module.common import DistributionalDQNnet # noqa @@ -165,14 +163,14 @@ def __init__( out_features: int | torch.Size = None, depth: int | None = None, num_cells: Sequence[int] | int | None = None, - activation_class: Type[nn.Module] | Callable = nn.Tanh, - activation_kwargs: dict | List[dict] | None = None, - norm_class: Type[nn.Module] | Callable | None = None, - norm_kwargs: dict | List[dict] | None = None, + activation_class: type[nn.Module] | Callable = nn.Tanh, + activation_kwargs: dict | list[dict] | None = None, + norm_class: type[nn.Module] | Callable | None = None, + norm_kwargs: dict | list[dict] | None = None, dropout: float | None = None, bias_last_layer: bool = True, single_bias_last_layer: bool = False, - layer_class: Type[nn.Module] | Callable = nn.Linear, + layer_class: type[nn.Module] | Callable = nn.Linear, layer_kwargs: dict | None = None, activate_last_layer: bool = False, device: DEVICE_TYPING | None = None, @@ -244,7 +242,7 @@ def __init__( ] super().__init__(*layers) - def _make_net(self, device: DEVICE_TYPING | None) -> List[nn.Module]: + def _make_net(self, device: DEVICE_TYPING | None) -> list[nn.Module]: layers = [] in_features = [self.in_features] + self.num_cells out_features = self.num_cells + [self._out_features_num] @@ -293,7 +291,7 @@ def _make_net(self, device: DEVICE_TYPING | None) -> List[nn.Module]: return layers - def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor: if len(inputs) > 1: inputs = (torch.cat([*inputs], -1),) @@ -408,15 +406,15 @@ def __init__( in_features: int | None = None, depth: int | None = None, num_cells: Sequence[int] | int = None, - kernel_sizes: Union[Sequence[int], int] = 3, + kernel_sizes: Sequence[int] | int = 3, strides: Sequence[int] | int = 1, paddings: Sequence[int] | int = 0, - activation_class: Type[nn.Module] | Callable = nn.ELU, - activation_kwargs: dict | List[dict] | None = None, - norm_class: 
Type[nn.Module] | Callable | None = None, - norm_kwargs: dict | List[dict] | None = None, + activation_class: type[nn.Module] | Callable = nn.ELU, + activation_kwargs: dict | list[dict] | None = None, + norm_class: type[nn.Module] | Callable | None = None, + norm_kwargs: dict | list[dict] | None = None, bias_last_layer: bool = True, - aggregator_class: Type[nn.Module] | Callable | None = SquashDims, + aggregator_class: type[nn.Module] | Callable | None = SquashDims, aggregator_kwargs: dict | None = None, squeeze_output: bool = False, device: DEVICE_TYPING | None = None, @@ -540,7 +538,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: *batch, C, L, W = inputs.shape if len(batch) > 1: inputs = inputs.flatten(0, len(batch) - 1) - out = super(ConvNet, self).forward(inputs) + out = super().forward(inputs) if len(batch) > 1: out = out.unflatten(0, batch) return out @@ -678,12 +676,12 @@ def __init__( kernel_sizes: Sequence[int] | int = 3, strides: Sequence[int] | int = 1, paddings: Sequence[int] | int = 0, - activation_class: Type[nn.Module] | Callable = nn.ELU, - activation_kwargs: dict | List[dict] | None = None, - norm_class: Type[nn.Module] | Callable | None = None, - norm_kwargs: dict | List[dict] | None = None, + activation_class: type[nn.Module] | Callable = nn.ELU, + activation_kwargs: dict | list[dict] | None = None, + norm_class: type[nn.Module] | Callable | None = None, + norm_kwargs: dict | list[dict] | None = None, bias_last_layer: bool = True, - aggregator_class: Type[nn.Module] | Callable | None = SquashDims, + aggregator_class: type[nn.Module] | Callable | None = SquashDims, aggregator_kwargs: dict | None = None, squeeze_output: bool = False, device: DEVICE_TYPING | None = None, @@ -1199,7 +1197,7 @@ def __init__( self.mlp = MLP(device=device, **mlp_net_default_kwargs) ddpg_init_last_layer(self.mlp, 6e-4, device=device) - def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, observation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden = self.convnet(observation) action = self.mlp(hidden) return action, hidden @@ -1478,7 +1476,7 @@ def __init__( "bias_last_layer": True, "activate_last_layer": True, } - mlp_net_kwargs_net1: Dict = ( + mlp_net_kwargs_net1: dict = ( mlp_net_kwargs_net1 if mlp_net_kwargs_net1 is not None else {} ) mlp1_net_default_kwargs.update(mlp_net_kwargs_net1) @@ -1539,7 +1537,7 @@ def __init__( self, state_dim: int, action_dim: int, - transformer_config: Dict | DecisionTransformer.DTConfig = None, + transformer_config: dict | DecisionTransformer.DTConfig = None, device: DEVICE_TYPING | None = None, ): super().__init__() @@ -1577,7 +1575,7 @@ def forward( observation: torch.Tensor, action: torch.Tensor, return_to_go: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: hidden_state = self.transformer(observation, action, return_to_go) mu = self.action_layer_mean(hidden_state) log_std = self.action_layer_logstd(hidden_state) @@ -1638,7 +1636,7 @@ def __init__( self, state_dim: int, action_dim: int, - transformer_config: Dict | DecisionTransformer.DTConfig = None, + transformer_config: dict | DecisionTransformer.DTConfig = None, device: DEVICE_TYPING | None = None, ): super().__init__() @@ -1690,7 +1688,7 @@ def default_config(cls): ) -def _iter_maybe_over_single(item: dict | List[dict] | None, n): +def _iter_maybe_over_single(item: dict | list[dict] | None, n): if item is None: return iter([{} for _ in range(n)]) elif isinstance(item, dict): @@ 
-1703,7 +1701,7 @@ class _ExecutableLayer(nn.Module): """A thin wrapper around a function to be executed as a module.""" def __init__(self, func): - super(_ExecutableLayer, self).__init__() + super().__init__() self.func = func def forward(self, *args, **kwargs): diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 71b5c254d0a..8dd7cbe8743 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -7,16 +7,13 @@ import abc from copy import deepcopy from textwrap import indent -from typing import Optional, Sequence, Tuple, Type, Union +from typing import Sequence, Type import numpy as np - import torch - from tensordict import TensorDict from torch import nn from torchrl.data.utils import DEVICE_TYPING - from torchrl.modules.models import ConvNet, MLP from torchrl.modules.models.utils import _reset_parameters_recursive @@ -129,7 +126,7 @@ def exec_module(params, *input): return torch.vmap(exec_module, *args, **kwargs) - def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor: if len(inputs) > 1: inputs = torch.cat([*inputs], -1) else: @@ -418,10 +415,10 @@ def __init__( *, centralized: bool | None = None, share_params: bool | None = None, - device: Optional[DEVICE_TYPING] = None, - depth: Optional[int] = None, - num_cells: Optional[Union[Sequence, int]] = None, - activation_class: Optional[Type[nn.Module]] = nn.Tanh, + device: DEVICE_TYPING | None = None, + depth: int | None = None, + num_cells: Sequence | int | None = None, + activation_class: type[nn.Module] | None = nn.Tanh, use_td_params: bool = True, **kwargs, ): @@ -631,10 +628,10 @@ def __init__( in_features: int | None = None, device: DEVICE_TYPING | None = None, num_cells: Sequence[int] | None = None, - kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5, - strides: Union[Sequence, int] = 2, - paddings: Union[Sequence, int] = 0, - activation_class: Type[nn.Module] = nn.ELU, + kernel_sizes: Sequence[int | Sequence[int]] | int = 5, + strides: Sequence | int = 2, + paddings: Sequence | int = 0, + activation_class: type[nn.Module] = nn.ELU, use_td_params: bool = True, **kwargs, ): @@ -789,7 +786,7 @@ def __init__( self, n_agents: int, needs_state: bool, - state_shape: Union[Tuple[int, ...], torch.Size], + state_shape: tuple[int, ...] | torch.Size, device: DEVICE_TYPING, ): super().__init__() @@ -799,7 +796,7 @@ def __init__( self.needs_state = needs_state self.state_shape = state_shape - def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor: """Forward pass of the mixer. Args: @@ -1001,7 +998,7 @@ class QMixer(Mixer): def __init__( self, - state_shape: Union[Tuple[int, ...], torch.Size], + state_shape: tuple[int, ...] 
| torch.Size, mixing_embed_dim: int, n_agents: int, device: DEVICE_TYPING, diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py index 5a59bc55fa1..da92caa08dd 100644 --- a/torchrl/modules/models/recipes/impala.py +++ b/torchrl/modules/models/recipes/impala.py @@ -17,7 +17,7 @@ def __init__( self, num_ch, ): - super(_ResNetBlock, self).__init__() + super().__init__() resnet_block = [] resnet_block.append(nn.ReLU(inplace=True)) resnet_block.append( diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index 0c650087235..510ee423a64 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -10,9 +10,7 @@ import torch from torch import nn - from torchrl.data.utils import DEVICE_TYPING - from torchrl.modules.models.exploration import NoisyLazyLinear, NoisyLinear LazyMapping = { @@ -114,7 +112,7 @@ def _find_depth(depth: int | None, *list_or_ints: Sequence): def create_on_device( - module_class: Type[nn.Module] | Callable, + module_class: type[nn.Module] | Callable, device: DEVICE_TYPING | None, *args, **kwargs, diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index ca76acc4160..4a33d0d7fbb 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -4,23 +4,21 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence import torch - from tensordict import TensorDictBase, unravel_key from tensordict.nn import ( CompositeDistribution, - dispatch, TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, TensorDictSequential, + dispatch, ) -from tensordict.utils import expand_as_right, NestedKey +from tensordict.utils import NestedKey, expand_as_right from torch import nn from torch.distributions import Categorical - from torchrl._utils import _replace_last from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _process_action_space_spec @@ -98,10 +96,10 @@ class Actor(SafeModule): def __init__( self, module: nn.Module, - in_keys: Optional[Sequence[NestedKey]] = None, - out_keys: Optional[Sequence[NestedKey]] = None, + in_keys: Sequence[NestedKey] | None = None, + out_keys: Sequence[NestedKey] | None = None, *, - spec: Optional[TensorSpec] = None, + spec: TensorSpec | None = None, **kwargs, ): if in_keys is None: @@ -360,10 +358,10 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): def __init__( self, module: TensorDictModule, - in_keys: Union[NestedKey, Sequence[NestedKey]], - out_keys: Optional[Sequence[NestedKey]] = None, + in_keys: NestedKey | Sequence[NestedKey], + out_keys: Sequence[NestedKey] | None = None, *, - spec: Optional[TensorSpec] = None, + spec: TensorSpec | None = None, **kwargs, ): distribution_class = kwargs.get("distribution_class") @@ -450,8 +448,8 @@ class ValueOperator(TensorDictModule): def __init__( self, module: nn.Module, - in_keys: Optional[Sequence[NestedKey]] = None, - out_keys: Optional[Sequence[NestedKey]] = None, + in_keys: Sequence[NestedKey] | None = None, + out_keys: Sequence[NestedKey] | None = None, ) -> None: if in_keys is None: in_keys = ["observation"] @@ -532,12 +530,12 @@ class QValueModule(TensorDictModuleBase): def __init__( self, - action_space: Optional[str] = None, - action_value_key: Optional[NestedKey] = None, - action_mask_key: Optional[NestedKey] = 
None, - out_keys: Optional[Sequence[NestedKey]] = None, - var_nums: Optional[int] = None, - spec: Optional[TensorSpec] = None, + action_space: str | None = None, + action_value_key: NestedKey | None = None, + action_mask_key: NestedKey | None = None, + out_keys: Sequence[NestedKey] | None = None, + var_nums: int | None = None, + spec: TensorSpec | None = None, safe: bool = False, ): if isinstance(action_space, TensorSpec): @@ -748,12 +746,12 @@ class DistributionalQValueModule(QValueModule): def __init__( self, - action_space: Optional[str], + action_space: str | None, support: torch.Tensor, - action_value_key: Optional[NestedKey] = None, - action_mask_key: Optional[NestedKey] = None, - out_keys: Optional[Sequence[NestedKey]] = None, - var_nums: Optional[int] = None, + action_value_key: NestedKey | None = None, + action_mask_key: NestedKey | None = None, + out_keys: Sequence[NestedKey] | None = None, + var_nums: int | None = None, spec: TensorSpec = None, safe: bool = False, ): @@ -911,10 +909,10 @@ class QValueHook: def __init__( self, action_space: str, - var_nums: Optional[int] = None, - action_value_key: Optional[NestedKey] = None, - action_mask_key: Optional[NestedKey] = None, - out_keys: Optional[Sequence[NestedKey]] = None, + var_nums: int | None = None, + action_value_key: NestedKey | None = None, + action_mask_key: NestedKey | None = None, + out_keys: Sequence[NestedKey] | None = None, ): if isinstance(action_space, TensorSpec): raise RuntimeError( @@ -938,7 +936,7 @@ def __init__( def __call__( self, net: nn.Module, observation: torch.Tensor, values: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: kwargs = {self.action_value_key: values} return self.qvalue_model(**kwargs) @@ -1007,10 +1005,10 @@ def __init__( self, action_space: str, support: torch.Tensor, - var_nums: Optional[int] = None, - action_value_key: Optional[NestedKey] = None, - action_mask_key: Optional[NestedKey] = None, - out_keys: Optional[Sequence[NestedKey]] = None, + var_nums: int | None = None, + action_value_key: NestedKey | None = None, + action_mask_key: NestedKey | None = None, + out_keys: Sequence[NestedKey] | None = None, ): if isinstance(action_space, TensorSpec): raise RuntimeError("Using specs in action_space is deprecated") @@ -1125,9 +1123,9 @@ def __init__( in_keys=None, spec=None, safe=False, - action_space: Optional[str] = None, + action_space: str | None = None, action_value_key=None, - action_mask_key: Optional[NestedKey] = None, + action_mask_key: NestedKey | None = None, ): if isinstance(action_space, TensorSpec): raise RuntimeError( @@ -1268,10 +1266,10 @@ def __init__( in_keys=None, spec=None, safe=False, - var_nums: Optional[int] = None, - action_space: Optional[str] = None, + var_nums: int | None = None, + action_space: str | None = None, action_value_key: str = "action_value", - action_mask_key: Optional[NestedKey] = None, + action_mask_key: NestedKey | None = None, make_log_softmax: bool = True, ): if isinstance(action_space, TensorSpec): @@ -1836,7 +1834,7 @@ def __init__( policy: TensorDictModule, *, inference_context: int = 5, - spec: Optional[TensorSpec] = None, + spec: TensorSpec | None = None, device: torch.device | None = None, ): super().__init__(policy) @@ -2066,7 +2064,7 @@ def __init__( high=None, clamp: bool = False, ): - super(TanhModule, self).__init__() + super().__init__() self.in_keys = in_keys if out_keys is None: out_keys = in_keys @@ -2291,8 +2289,8 @@ def __init__( actor: 
TensorDictModuleBase, n_steps: int, *, - action_keys: List[NestedKey] | None = None, - init_key: List[NestedKey] | None = None, + action_keys: list[NestedKey] | None = None, + init_key: list[NestedKey] | None = None, ): self.action_keys = action_keys self.init_key = init_key @@ -2387,7 +2385,7 @@ def forward( return tensordict @property - def action_keys(self) -> List[NestedKey]: + def action_keys(self) -> list[NestedKey]: action_keys = self.__dict__.get("_action_keys", None) if action_keys is None: @@ -2411,7 +2409,7 @@ def action_keys(self, value): self._action_keys = [unravel_key(key) for key in value] @property - def _actor_keys_map(self) -> Dict[NestedKey, NestedKey]: + def _actor_keys_map(self) -> dict[NestedKey, NestedKey]: val = self.__dict__.get("_actor_keys_map_values", None) if val is None: diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 7914f663a45..f2b40e645e4 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -9,20 +9,15 @@ import inspect import re import warnings -from typing import Iterable, List, Optional, Type, Union +from typing import Iterable import torch - from tensordict import TensorDictBase, unravel_key_list - -from tensordict.nn import dispatch, TensorDictModule, TensorDictModuleBase +from tensordict.nn import TensorDictModule, TensorDictModuleBase, dispatch from tensordict.utils import NestedKey - from torch import nn from torch.nn import functional as F - from torchrl.data.tensor_specs import Composite, TensorSpec - from torchrl.data.utils import DEVICE_TYPING _has_functorch = importlib.util.find_spec("functorch") is not None @@ -194,12 +189,15 @@ class SafeModule(TensorDictModule): def __init__( self, - module: Union[ - FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module - ], + module: ( + FunctionalModule + | FunctionalModuleWithBuffers + | TensorDictModule + | nn.Module + ), in_keys: Iterable[str], out_keys: Iterable[str], - spec: Optional[TensorSpec] = None, + spec: TensorSpec | None = None, safe: bool = False, ): super().__init__(module, in_keys, out_keys) @@ -282,14 +280,14 @@ def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: """See :obj:`TensorDictModule.random(...)`.""" return self.random(tensordict) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> TensorDictModule: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> TensorDictModule: if hasattr(self, "spec") and self.spec is not None: self.spec = self.spec.to(dest) out = super().to(dest) return out -def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): +def is_tensordict_compatible(module: TensorDictModule | nn.Module): """Returns `True` if a module can be used as a TensorDictModule, and False if it can't. If the signature is misleading an error is raised. 
@@ -356,13 +354,13 @@ def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): def ensure_tensordict_compatible( - module: Union[ - FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module - ], - in_keys: Optional[List[NestedKey]] = None, - out_keys: Optional[List[NestedKey]] = None, + module: ( + FunctionalModule | FunctionalModuleWithBuffers | TensorDictModule | nn.Module + ), + in_keys: list[NestedKey] | None = None, + out_keys: list[NestedKey] | None = None, safe: bool = False, - wrapper_type: Optional[Type] = TensorDictModule, + wrapper_type: type | None = TensorDictModule, **kwargs, ): """Ensures module is compatible with TensorDictModule and, if not, it wraps it.""" diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 62ac9710cc5..5bbe40a725d 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -5,22 +5,19 @@ from __future__ import annotations import warnings -from typing import Optional, Union import numpy as np import torch from tensordict import TensorDictBase - from tensordict.nn import ( TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, ) -from tensordict.utils import expand_as_right, expand_right, NestedKey +from tensordict.utils import NestedKey, expand_as_right, expand_right from torch import nn - from torchrl.data.tensor_specs import Composite, TensorSpec -from torchrl.envs.utils import exploration_type, ExplorationType +from torchrl.envs.utils import ExplorationType, exploration_type from torchrl.modules.tensordict_module.common import _forward_hook_safe_action __all__ = [ @@ -95,8 +92,8 @@ def __init__( eps_end: float = 0.1, annealing_num_steps: int = 1000, *, - action_key: Optional[NestedKey] = "action", - action_mask_key: Optional[NestedKey] = None, + action_key: NestedKey | None = "action", + action_mask_key: NestedKey | None = None, device: torch.device | None = None, ): if not isinstance(eps_init, float): @@ -209,9 +206,9 @@ def __init__( eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, - action_key: Optional[NestedKey] = "action", - action_mask_key: Optional[NestedKey] = None, - spec: Optional[TensorSpec] = None, + action_key: NestedKey | None = "action", + action_mask_key: NestedKey | None = None, + spec: TensorSpec | None = None, ): raise RuntimeError( "This class has been deprecated in favor of torchrl.modules.EGreedyModule." 
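A note on the Optional[X] to X | None rewrites in these hunks: they are annotation-only changes. Every touched module has from __future__ import annotations at the top, so annotations are stored as strings and never evaluated, and the pipe syntax parses on interpreters older than 3.10 (assuming a 3.8 floor). A small hypothetical sketch of the boundary between annotation and runtime use:

# Hypothetical sketch: why `int | None` is safe here in annotations only.
from __future__ import annotations  # annotations kept as strings, never evaluated


def clamp(x: int | None = None, lo: int = 0) -> int | None:
    # Runtime expressions are not annotations: `isinstance(x, int | None)`
    # would raise TypeError before Python 3.10, so use a plain type check.
    return max(x, lo) if isinstance(x, int) else None


print(clamp.__annotations__)  # values are plain strings such as 'int | None'
print(clamp(-3), clamp(None))  # 0 None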
@@ -230,9 +227,9 @@ def __init__( annealing_num_steps: int = 1000, mean: float = 0.0, std: float = 1.0, - action_key: Optional[NestedKey] = "action", - spec: Optional[TensorSpec] = None, - safe: Optional[bool] = True, + action_key: NestedKey | None = "action", + spec: TensorSpec | None = None, + safe: bool | None = True, device: torch.device | None = None, ): raise RuntimeError( @@ -287,7 +284,7 @@ def __init__( mean: float = 0.0, std: float = 1.0, *, - action_key: Optional[NestedKey] = "action", + action_key: NestedKey | None = "action", # safe is already implemented because we project in the noise addition safe: bool = False, device: torch.device | None = None, @@ -383,14 +380,14 @@ def __init__( mu: float = 0.0, sigma: float = 0.2, dt: float = 1e-2, - x0: Optional[Union[torch.Tensor, np.ndarray]] = None, - sigma_min: Optional[float] = None, + x0: torch.Tensor | np.ndarray | None = None, + sigma_min: float | None = None, n_steps_annealing: int = 1000, - action_key: Optional[NestedKey] = "action", - is_init_key: Optional[NestedKey] = "is_init", + action_key: NestedKey | None = "action", + is_init_key: NestedKey | None = "is_init", spec: TensorSpec = None, safe: bool = True, - key: Optional[NestedKey] = None, + key: NestedKey | None = None, device: torch.device | None = None, ): raise RuntimeError( @@ -611,11 +608,11 @@ def __init__( mu: float = 0.0, sigma: float = 0.2, dt: float = 1e-2, - x0: Optional[Union[torch.Tensor, np.ndarray]] = None, - sigma_min: Optional[float] = None, + x0: torch.Tensor | np.ndarray | None = None, + sigma_min: float | None = None, n_steps_annealing: int = 1000, - key: Optional[NestedKey] = "action", - is_init_key: Optional[NestedKey] = "is_init", + key: NestedKey | None = "action", + is_init_key: NestedKey | None = "is_init", device: torch.device | None = None, ): super().__init__() @@ -688,7 +685,7 @@ def add_sample( self, tensordict: TensorDictBase, eps: float = 1.0, - is_init: Optional[torch.Tensor] = None, + is_init: torch.Tensor | None = None, ) -> TensorDictBase: # Get the nested tensordict where the action lives diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 79b0d015823..e03685d55d8 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -5,12 +5,10 @@ from __future__ import annotations import warnings -from typing import Dict, List, Optional, Union +from typing import Dict, List import torch - from tensordict import TensorDictBase, unravel_key_list - from tensordict.nn import ( InteractionType, ProbabilisticTensorDictModule, @@ -186,16 +184,16 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): def __init__( self, - in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey], - out_keys: NestedKey | List[NestedKey] | None = None, - spec: Optional[TensorSpec] = None, + in_keys: NestedKey | list[NestedKey] | dict[str, NestedKey], + out_keys: NestedKey | list[NestedKey] | None = None, + spec: TensorSpec | None = None, *, safe: bool = False, default_interaction_type: InteractionType = InteractionType.DETERMINISTIC, distribution_class: type = Delta, distribution_kwargs: dict | None = None, return_log_prob: bool = False, - log_prob_keys: List[NestedKey] | None = None, + log_prob_keys: list[NestedKey] | None = None, log_prob_key: NestedKey | None = None, cache_dist: bool = False, n_empirical_estimate: int = 1000, @@ -315,7 +313,7 @@ class SafeProbabilisticTensorDictSequential( def __init__( self, - 
*modules: Union[TensorDictModule, ProbabilisticTensorDictModule], + *modules: TensorDictModule | ProbabilisticTensorDictModule, partial_tolerant: bool = False, ) -> None: super().__init__(*modules, partial_tolerant=partial_tolerant) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 815756c528a..be23b985498 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -6,20 +6,16 @@ import typing import warnings -from typing import Any, Optional, Tuple +from typing import Any import torch import torch.nn.functional as F from tensordict import TensorDictBase, unravel_key_list - from tensordict.base import NO_DEFAULT - -from tensordict.nn import dispatch, TensorDictModuleBase as ModuleBase +from tensordict.nn import TensorDictModuleBase as ModuleBase, dispatch from tensordict.utils import expand_as_right, prod, set_lazy_legacy - -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn.modules.rnn import RNNCellBase - from torchrl._utils import _ContextManager, _DecoratorContextManager from torchrl.data.tensor_specs import Unbounded from torchrl.objectives.value.functional import ( @@ -78,8 +74,8 @@ def __init__( super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ) -> Tuple[Tensor, Tensor]: + self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None + ) -> tuple[Tensor, Tensor]: if input.dim() not in (1, 2): raise ValueError( f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead" @@ -795,16 +791,16 @@ def _lstm( steps, device, dtype, - hidden0_in: Optional[torch.Tensor] = None, - hidden1_in: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if not self.recurrent_mode and steps != 1: raise ValueError("Expected a single step") if hidden1_in is None and hidden0_in is None: shape = (batch, steps) - hidden0_in, hidden1_in = [ + hidden0_in, hidden1_in = ( torch.zeros( *shape, self.lstm.num_layers, @@ -813,7 +809,7 @@ def _lstm( dtype=dtype, ) for _ in range(2) - ] + ) elif hidden1_in is None or hidden0_in is None: raise RuntimeError( f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" @@ -887,7 +883,7 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: if input.dim() not in (1, 2): raise ValueError( f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" @@ -1606,8 +1602,8 @@ def _gru( steps, device, dtype, - hidden_in: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_in: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if not self.recurrent_mode and steps != 1: raise ValueError("Expected a single step") diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 65ef79d4606..369f1d56ad0 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -7,38 +7,30 @@ import contextlib from copy import deepcopy from dataclasses import dataclass -from typing import Tuple import torch -from tensordict import ( - 
is_tensor_collection, - TensorDict, - TensorDictBase, - TensorDictParams, -) +from tensordict import ( + TensorDict, + TensorDictBase, + TensorDictParams, + is_tensor_collection, +) from tensordict.nn import ( - composite_lp_aggregate, CompositeDistribution, - dispatch, ProbabilisticTensorDictSequential, - set_composite_lp_aggregate, TensorDictModule, + composite_lp_aggregate, + dispatch, + set_composite_lp_aggregate, ) from tensordict.utils import NestedKey from torch import distributions as d - from torchrl.modules.distributions import HAS_ENTROPY from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( + ValueEstimators, + _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, _clip_value_loss, - _GAMMA_LMBDA_DEPREC_ERROR, _get_default_device, _reduce, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import ( GAE, @@ -437,7 +429,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: @set_composite_lp_aggregate(False) def _log_probs( self, tensordict: TensorDictBase - ) -> Tuple[torch.Tensor, d.Distribution]: + ) -> tuple[torch.Tensor, d.Distribution]: # current log_prob of actions tensordict_clone = tensordict.select( *self.actor_network.in_keys, strict=False @@ -466,7 +458,7 @@ def _log_probs( log_prob = log_prob.unsqueeze(-1) return log_prob, dist - def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: + def loss_critic(self, tensordict: TensorDictBase) -> tuple[torch.Tensor, float]: """Returns the loss value of the critic, multiplied by ``critic_coef`` if it is not ``None``. Returns the loss and the clip-fraction. diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index e4782195f4d..33b00ff969b 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -10,11 +10,10 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Iterator, List, Optional, Tuple +from typing import Iterator import torch -from tensordict import is_tensor_collection, TensorDict, TensorDictBase - +from tensordict import TensorDict, TensorDictBase, is_tensor_collection from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams from tensordict.utils import Buffer from torch import nn @@ -22,8 +21,7 @@ from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import set_recurrent_mode - -from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators +from torchrl.objectives.utils import ValueEstimators from torchrl.objectives.value import ValueEstimatorBase try: @@ -280,9 +278,9 @@ def convert_to_functional( self, module: TensorDictModule, module_name: str, - expand_dim: Optional[int] = None, + expand_dim: int | None = None, create_target_params: bool = False, - compare_against: Optional[List[Parameter]] = None, + compare_against: list[Parameter] | None = None, **kwargs, ) -> None: """Converts a module to functional to be used in the loss.
@@ -486,7 +484,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: def named_parameters( self, prefix: str = "", recurse: bool = True - ) -> Iterator[Tuple[str, Parameter]]: + ) -> Iterator[tuple[str, Parameter]]: for name, param in super().named_parameters(prefix=prefix, recurse=recurse): if not name.startswith("_target"): yield name, param diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 0e3a2447650..fd6c4231312 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -8,34 +8,30 @@ import warnings from copy import deepcopy from dataclasses import dataclass - -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np import torch import torch.nn as nn from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey, unravel_key from torch import Tensor - from torchrl.data.tensor_specs import Composite from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type - from torchrl.modules import ProbabilisticActor, QValueActor from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _cache_values, + ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, + _cache_values, _reduce, _vmap_func, default_value_kwargs, distance_loss, - ValueEstimators, ) - from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -274,7 +270,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, loss_function: str = "smooth_l1", alpha_init: float = 1.0, @@ -282,7 +278,7 @@ def __init__( max_alpha: float = None, action_spec=None, fixed_alpha: bool = False, - target_entropy: Union[str, float] = "auto", + target_entropy: str | float = "auto", delay_actor: bool = False, delay_qvalue: bool = True, gamma: float = None, @@ -581,7 +577,7 @@ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor: ) return bc_actor_loss, metadata - def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -705,7 +701,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return target_value - def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def q_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. 
target_value = self._get_value_v( tensordict.copy(), @@ -743,7 +739,7 @@ def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return loss_qval, metadata - def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def cql_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: pred_q1 = tensordict.get(self.tensor_keys.pred_q1) pred_q2 = tensordict.get(self.tensor_keys.pred_q2) @@ -1089,9 +1085,9 @@ class _AcceptedKeys: def __init__( self, - value_network: Union[QValueActor, nn.Module], + value_network: QValueActor | nn.Module, *, - loss_function: Optional[str] = "l2", + loss_function: str | None = "l2", delay_value: bool = True, gamma: float = None, action_space=None, @@ -1218,7 +1214,7 @@ def in_keys(self, values): def value_loss( self, tensordict: TensorDictBase, - ) -> Tuple[torch.Tensor, dict]: + ) -> tuple[torch.Tensor, dict]: td_copy = tensordict.clone(False) with self.value_network_params.to_module(self.value_network): self.value_network(td_copy) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 45976c3c48f..4bf314bce4c 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -7,26 +7,23 @@ import math from dataclasses import dataclass from functools import wraps -from typing import Dict, List, Tuple, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams - -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey from torch import Tensor from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( + ValueEstimators, _cache_values, _reduce, _vmap_func, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -256,7 +253,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", @@ -265,7 +262,7 @@ def __init__( max_alpha: float = None, action_spec=None, fixed_alpha: bool = False, - target_entropy: Union[str, float] = "auto", + target_entropy: str | float = "auto", priority_key: str = None, separate_losses: bool = False, reduction: str = None, @@ -559,7 +556,7 @@ def _cached_detached_qvalue_params(self): def actor_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: """Compute the actor loss. The actor loss should be computed after the :meth:`~.qvalue_loss` and before the `~.alpha_loss` which @@ -601,7 +598,7 @@ def actor_loss( def qvalue_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: """Compute the q-value loss. The q-value loss should be computed before the :meth:`~.actor_loss`. 
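The two docstrings above fix an update order: the q-value loss is computed first, then the actor loss, with the alpha loss last since it reuses quantities produced by the actor pass. A hedged sketch of that order; `loss_module`, `batch` and the optimizers are assumed names:

    # Assumed: loss_module is a CrossQLoss, batch a sampled TensorDict.
    q_loss, q_metadata = loss_module.qvalue_loss(batch)  # 1. critic loss first
    critic_optim.zero_grad()
    q_loss.backward()
    critic_optim.step()
    a_loss, a_metadata = loss_module.actor_loss(batch)   # 2. actor, on detached q-params
    actor_optim.zero_grad()
    a_loss.backward()
    actor_optim.step()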
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 34a7aa72242..6dd10de7e89 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -7,22 +7,20 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Tuple import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule - +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey, unravel_key from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _cache_values, + ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, + _cache_values, _reduce, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -340,7 +338,7 @@ def loss_actor( def loss_value( self, tensordict: TensorDictBase, - ) -> Tuple[torch.Tensor, dict]: + ) -> tuple[torch.Tensor, dict]: # value loss td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach() with self.value_network_params.to_module(self.value_network): diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 05ade582d2a..03cdc42e667 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -6,16 +6,13 @@ import math from dataclasses import dataclass -from typing import Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey - from torch import distributions as d from torchrl.modules import ProbabilisticActor - from torchrl.objectives.common import LossModule from torchrl.objectives.utils import _reduce, distance_loss @@ -85,7 +82,7 @@ def __init__( min_alpha: float = None, max_alpha: float = None, fixed_alpha: bool = False, - target_entropy: Union[str, float] = "auto", + target_entropy: str | float = "auto", samples_mc_entropy: int = 1, reduction: str = None, ) -> None: diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 5faadccfe93..a9eeae6d5d4 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -7,26 +7,18 @@ import math from dataclasses import dataclass from numbers import Number -from typing import List, Tuple, Union import numpy as np import torch - from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, composite_lp_aggregate, dispatch from tensordict.utils import NestedKey from torch import Tensor - from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp -from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators +from torchrl.objectives import ValueEstimators, default_value_kwargs, distance_loss from torchrl.objectives.common import LossModule -from torchrl.objectives.utils import ( - _cache_values, - _GAMMA_LMBDA_DEPREC_ERROR, - _reduce, - _vmap_func, -) +from torchrl.objectives.utils import (_GAMMA_LMBDA_DEPREC_ERROR, _cache_values, _reduce, _vmap_func) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -149,7 +141,7 @@ def 
__post_init__(self): def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, num_qvalue_nets: int = 10, sub_sample_len: int = 2, @@ -159,7 +151,7 @@ def __init__( max_alpha: float = 10.0, action_spec=None, fixed_alpha: bool = False, - target_entropy: Union[str, Number] = "auto", + target_entropy: str | Number = "auto", delay_qvalue: bool = True, gSDE: bool = False, gamma: float = None, @@ -362,7 +354,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detach_qvalue_network_params(self): return self.qvalue_network_params.detach() - def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def _actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, Tensor]: obs_keys = self.actor_network.in_keys tensordict_clone = tensordict.select(*obs_keys, strict=False) with set_exploration_type( diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 21376bfd5b2..ad5b7711b5b 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -6,31 +6,27 @@ import warnings from dataclasses import dataclass -from typing import Optional, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey from torch import nn from torchrl.data.tensor_specs import TensorSpec - from torchrl.data.utils import _find_action_space - from torchrl.envs.utils import step_mdp from torchrl.modules.tensordict_module.actors import ( DistributionalQValueActor, QValueActor, ) from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible - from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( + ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, _reduce, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TDLambdaEstimator from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator @@ -175,13 +171,13 @@ class _AcceptedKeys: def __init__( self, - value_network: Union[QValueActor, nn.Module], + value_network: QValueActor | nn.Module, *, - loss_function: Optional[str] = "l2", + loss_function: str | None = "l2", delay_value: bool = True, double_dqn: bool = False, gamma: float = None, - action_space: Union[str, TensorSpec] = None, + action_space: str | TensorSpec = None, priority_key: str = None, reduction: str = None, ) -> None: @@ -454,7 +450,7 @@ class _AcceptedKeys: def __init__( self, - value_network: Union[DistributionalQValueActor, nn.Module], + value_network: DistributionalQValueActor | nn.Module, *, gamma: float, delay_value: bool = True, diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index a8c439288eb..054b9de9d2c 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -5,25 +5,22 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule from tensordict.utils import NestedKey - from torchrl._utils import timeit from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( + ValueEstimators, 
_GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, - # distance_loss, hold_out_net, - ValueEstimators, -) +) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -103,8 +100,8 @@ def __init__( lambda_kl: float = 1.0, lambda_reco: float = 1.0, lambda_reward: float = 1.0, - reco_loss: Optional[str] = None, - reward_loss: Optional[str] = None, + reco_loss: str | None = None, + reward_loss: str | None = None, free_nats: int = 3, delayed_clamp: bool = False, global_average: bool = False, @@ -277,7 +274,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, ) - def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: + def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]: tensordict = tensordict.select("state", self.tensor_keys.belief).detach() with timeit("actor_loss/time-rollout"), hold_out_net( @@ -409,7 +406,7 @@ class _AcceptedKeys: def __init__( self, value_model: TensorDictModule, - value_loss: Optional[str] = None, + value_loss: str | None = None, discount_loss: bool = True, # for consistency with paper gamma: int = 0.99, ): diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index ca1efcc337b..087204d1ad6 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -6,25 +6,23 @@ import warnings from dataclasses import dataclass -from typing import List, Optional, Tuple, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey from torch import Tensor from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import _find_action_space - from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( + ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, _reduce, _vmap_func, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -256,8 +254,8 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule | List[TensorDictModule], - value_network: Optional[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], + value_network: TensorDictModule | None, *, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", @@ -410,7 +408,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) return td_out - def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -455,7 +453,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return loss_actor, {} - def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) @@ -478,7 +476,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return value_loss, {} - def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def qvalue_loss(self,
tensordict: TensorDictBase) -> tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys tensordict = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False @@ -769,9 +767,9 @@ def __init__( self, actor_network: ProbabilisticActor, qvalue_network: TensorDictModule, - value_network: Optional[TensorDictModule], + value_network: TensorDictModule | None, *, - action_space: Union[str, TensorSpec] = None, + action_space: str | TensorSpec = None, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", temperature: float = 1.0, @@ -809,7 +807,7 @@ def __init__( self.action_space = _find_action_space(action_space) self.reduction = reduction - def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -870,7 +868,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return loss_actor, {} - def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # Min Q value with torch.no_grad(): # Min Q value @@ -914,7 +912,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return value_loss, {} - def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys next_td = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index f3572cef9df..66307403283 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -8,30 +8,24 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey from torch import nn - from torchrl.data.tensor_specs import TensorSpec - from torchrl.data.utils import _find_action_space - from torchrl.modules import SafeSequential from torchrl.modules.tensordict_module.actors import QValueActor from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible - from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( - _cache_values, + ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, + _cache_values, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TDLambdaEstimator from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator @@ -193,13 +187,13 @@ class _AcceptedKeys: def __init__( self, - local_value_network: Union[QValueActor, nn.Module], - mixer_network: Union[TensorDictModule, nn.Module], + local_value_network: QValueActor | nn.Module, + mixer_network: TensorDictModule | nn.Module, *, - loss_function: Optional[str] = "l2", + loss_function: str | None = "l2", delay_value: bool = True, gamma: float = None, - action_space: Union[str, TensorSpec] = None, + action_space: str | TensorSpec = None, priority_key: str = None, ) -> None: super().__init__() diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 71455c83e2d..5cc5795584c 100644 --- 
a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -6,44 +6,35 @@ import contextlib import warnings - from copy import deepcopy from dataclasses import dataclass -from typing import List, Tuple import torch -from tensordict import ( - is_tensor_collection, - TensorDict, - TensorDictBase, - TensorDictParams, -) +from tensordict import (TensorDict, TensorDictBase, TensorDictParams, is_tensor_collection) from tensordict.nn import ( - composite_lp_aggregate, CompositeDistribution, - dispatch, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, - set_composite_lp_aggregate, TensorDictModule, + composite_lp_aggregate, + dispatch, + set_composite_lp_aggregate, ) from tensordict.utils import NestedKey from torch import distributions as d - from torchrl._utils import _standardize from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( + ValueEstimators, + _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, _clip_value_loss, - _GAMMA_LMBDA_DEPREC_ERROR, _maybe_add_or_extend_key, _maybe_get_or_select, _reduce, _sum_td_features, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import ( GAE, @@ -298,11 +289,11 @@ class _AcceptedKeys: advantage: NestedKey = "advantage" value_target: NestedKey = "value_target" value: NestedKey = "state_value" - sample_log_prob: NestedKey | List[NestedKey] | None = None - action: NestedKey | List[NestedKey] = "action" - reward: NestedKey | List[NestedKey] = "reward" - done: NestedKey | List[NestedKey] = "done" - terminated: NestedKey | List[NestedKey] = "terminated" + sample_log_prob: NestedKey | list[NestedKey] | None = None + action: NestedKey | list[NestedKey] = "action" + reward: NestedKey | list[NestedKey] = "reward" + done: NestedKey | list[NestedKey] = "done" + terminated: NestedKey | list[NestedKey] = "terminated" def __post_init__(self): if self.sample_log_prob is None: @@ -333,7 +324,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, - normalize_advantage_exclude_dims: Tuple[int] = (), + normalize_advantage_exclude_dims: tuple[int] = (), gamma: float = None, separate_losses: bool = False, advantage_key: str = None, @@ -521,7 +512,7 @@ def _get_entropy( def _log_weight( self, tensordict: TensorDictBase, adv_shape: torch.Size - ) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]: + ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]: with self.actor_network_params.to_module( self.actor_network @@ -891,7 +882,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, - normalize_advantage_exclude_dims: Tuple[int] = (), + normalize_advantage_exclude_dims: tuple[int] = (), gamma: float = None, separate_losses: bool = False, reduction: str = None, @@ -902,7 +893,7 @@ def __init__( if isinstance(clip_value, bool): clip_value = clip_epsilon if clip_value else None - super(ClipPPOLoss, self).__init__( + super().__init__( actor_network, critic_network, entropy_bonus=entropy_bonus, @@ -1162,14 +1153,14 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, - normalize_advantage_exclude_dims: Tuple[int] = (), + normalize_advantage_exclude_dims: tuple[int] = (), gamma: float = None, separate_losses: bool = False, reduction: str = None, clip_value: float | None = None, **kwargs, ): - super(KLPENPPOLoss, self).__init__( + super().__init__( actor_network, critic_network, entropy_bonus=entropy_bonus, diff 
--git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index fd94404c2c1..fca6375a634 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -7,27 +7,23 @@ import math from dataclasses import dataclass from numbers import Number -from typing import List, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams - -from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, composite_lp_aggregate, dispatch from tensordict.utils import NestedKey from torch import Tensor - from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( - _cache_values, + ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, + _cache_values, _reduce, _vmap_func, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -266,7 +262,7 @@ def __post_init__(self): def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, num_qvalue_nets: int = 10, sub_sample_len: int = 2, @@ -276,7 +272,7 @@ def __init__( max_alpha: float = 10.0, action_spec=None, fixed_alpha: bool = False, - target_entropy: Union[str, Number] = "auto", + target_entropy: str | Number = "auto", delay_qvalue: bool = True, gSDE: bool = False, gamma: float = None, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 12e027c08f5..71e01a36883 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -9,20 +9,18 @@ from dataclasses import dataclass from functools import wraps from numbers import Number -from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams - from tensordict.nn import ( - composite_lp_aggregate, CompositeDistribution, + TensorDictModule, + composite_lp_aggregate, dispatch, set_composite_lp_aggregate, - TensorDictModule, ) -from tensordict.utils import expand_right, NestedKey +from tensordict.utils import NestedKey, expand_right from torch import Tensor from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _find_action_space @@ -30,15 +28,14 @@ from torchrl.modules import ProbabilisticActor from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( - _cache_values, + ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, + _cache_values, _reduce, _vmap_func, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -317,8 +314,8 @@ def __post_init__(self): def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule | List[TensorDictModule], - value_network: Optional[TensorDictModule] = None, + qvalue_network: TensorDictModule | list[TensorDictModule], + value_network: TensorDictModule | None = None, *, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", @@ -327,7 +324,7 @@ def __init__( max_alpha: float = None, action_spec=None, fixed_alpha: bool = False, - target_entropy: Union[str, float] = "auto", + target_entropy: str | float = "auto", delay_actor: bool = False, delay_qvalue: bool = True, 
delay_value: bool = True, @@ -653,7 +650,7 @@ def _cached_detached_qvalue_params(self): def _actor_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -693,7 +690,7 @@ def _cached_target_params_actor_value(self): def _qvalue_v1_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: target_params = self._cached_target_params_actor_value with set_exploration_type(self.deterministic_sampling_mode): target_value = self.value_estimator.value_estimate( @@ -808,7 +805,7 @@ def _compute_target_v2(self, tensordict) -> Tensor: def _qvalue_v2_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. target_value = self._compute_target_v2(tensordict) @@ -830,7 +827,7 @@ def _qvalue_v2_loss( def _value_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: # value loss td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach() with self.value_network_params.to_module(self.value_network): @@ -1085,8 +1082,8 @@ def __init__( actor_network: ProbabilisticActor, qvalue_network: TensorDictModule, *, - action_space: Union[str, TensorSpec] = None, - num_actions: Optional[int] = None, + action_space: str | TensorSpec = None, + num_actions: int | None = None, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", alpha_init: float = 1.0, @@ -1094,7 +1091,7 @@ def __init__( max_alpha: float = None, fixed_alpha: bool = False, target_entropy_weight: float = 0.98, - target_entropy: Union[str, Number] = "auto", + target_entropy: str | Number = "auto", delay_qvalue: bool = True, priority_key: str = None, separate_losses: bool = False, @@ -1338,7 +1335,7 @@ def _compute_target(self, tensordict) -> Tensor: def _value_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: target_value = self._compute_target(tensordict) tensordict_expand = self._vmap_qnetworkN0( tensordict.select(*self.qvalue_network.in_keys, strict=False), @@ -1376,7 +1373,7 @@ def _value_loss( def _actor_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: # get probs and log probs for actions with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict.clone(False)) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 124620ab040..707193ba653 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -5,26 +5,22 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional, Tuple import torch - from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec - from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( - _cache_values, + ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, + _cache_values, _reduce, _vmap_func, 
default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -226,10 +222,10 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, action_spec: TensorSpec = None, - bounds: Optional[Tuple[float]] = None, + bounds: tuple[float] | None = None, num_qvalue_nets: int = 2, policy_noise: float = 0.2, noise_clip: float = 0.5, @@ -373,7 +369,7 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: + def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: tensordict_actor_grad = tensordict.select( *self.actor_network.in_keys, strict=False ) @@ -406,7 +402,7 @@ def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: ) return loss_actor, metadata - def value_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: + def value_loss(self, tensordict) -> tuple[torch.Tensor, dict]: tensordict = tensordict.clone(False) act = tensordict.get(self.tensor_keys.action) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index 08c79bdffd4..b64ad0758b9 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -5,25 +5,21 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional, Tuple import torch - from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import TensorDictModule, dispatch from tensordict.utils import NestedKey from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec - from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( + ValueEstimators, _cache_values, _reduce, _vmap_func, default_value_kwargs, distance_loss, - ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -241,10 +237,10 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, action_spec: TensorSpec = None, - bounds: Optional[Tuple[float]] = None, + bounds: tuple[float] | None = None, num_qvalue_nets: int = 2, policy_noise: float = 0.2, noise_clip: float = 0.5, @@ -387,7 +383,7 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: + def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: """Compute the actor loss. The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates. @@ -441,7 +437,7 @@ def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: ) return loss_actor, metadata - def qvalue_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: + def qvalue_loss(self, tensordict) -> tuple[torch.Tensor, dict]: """Compute the q-value loss. The q-value loss should be computed before the :meth:`~.actor_loss`. 
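The `actor_loss` docstring above notes the actor update usually lags the critic by 1-3 steps. A hedged sketch of that schedule; the buffer, optimizers and `policy_delay` are assumed names, not part of the patch:

    # TD3-style delayed policy update built from the two methods above.
    for step in range(total_steps):
        batch = replay_buffer.sample()
        q_loss, _ = loss_module.qvalue_loss(batch)  # always computed before actor_loss
        critic_optim.zero_grad()
        q_loss.backward()
        critic_optim.step()
        if step % policy_delay == 0:  # e.g. policy_delay = 2
            a_loss, _ = loss_module.actor_loss(batch)
            actor_optim.zero_grad()
            a_loss.backward()
            actor_optim.step()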
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 7ec4736862d..5c8c760ca5a 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -8,12 +8,12 @@ import re import warnings from enum import Enum -from typing import Iterable, List, Optional, Union +from typing import Iterable import torch from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictModule -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import functional as F from torch.nn.modules import dropout @@ -159,7 +159,7 @@ class TargetNetUpdater: def __init__( self, - loss_module: "LossModule", # noqa: F821 + loss_module: LossModule, # noqa: F821 ): from torchrl.objectives.common import LossModule @@ -284,7 +284,7 @@ def step(self) -> None: f"initialized (`{self.__class__.__name__}.init_()`) before calling step()" ) for key, param in self._sources.items(): - target = self._targets.get("target_{}".format(key)) + target = self._targets.get(f"target_{key}") if target.requires_grad: raise RuntimeError("the target parameter is part of a graph.") self._step(param, target) @@ -320,16 +320,16 @@ class SoftUpdate(TargetNetUpdater): def __init__( self, - loss_module: Union[ - "DQNLoss", # noqa: F821 - "DDPGLoss", # noqa: F821 - "SACLoss", # noqa: F821 - "REDQLoss", # noqa: F821 - "TD3Loss", # noqa: F821 - ], + loss_module: ( + DQNLoss # noqa: F821 + | DDPGLoss # noqa: F821 + | SACLoss # noqa: F821 + | REDQLoss # noqa: F821 + | TD3Loss # noqa: F821 + ), *, eps: float = None, - tau: Optional[float] = None, + tau: float | None = None, ): if eps is None and tau is None: raise RuntimeError( @@ -350,7 +350,7 @@ def __init__( raise ValueError( f"Got eps = {eps} when it was supposed to be between 0 and 1." ) - super(SoftUpdate, self).__init__(loss_module) + super().__init__(loss_module) self.eps = eps def _step( @@ -375,11 +375,11 @@ class HardUpdate(TargetNetUpdater): def __init__( self, - loss_module: Union["DQNLoss", "DDPGLoss", "SACLoss", "TD3Loss"], # noqa: F821 + loss_module: DQNLoss | DDPGLoss | SACLoss | TD3Loss, # noqa: F821 *, value_network_update_interval: float = 1000, ): - super(HardUpdate, self).__init__(loss_module) + super().__init__(loss_module) self.value_network_update_interval = value_network_update_interval self.counter = 0 @@ -441,10 +441,10 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: @torch.no_grad() def next_state_value( tensordict: TensorDictBase, - operator: Optional[TensorDictModule] = None, + operator: TensorDictModule | None = None, next_val_key: str = "state_action_value", gamma: float = 0.99, - pred_next_val: Optional[Tensor] = None, + pred_next_val: Tensor | None = None, **kwargs, ) -> torch.Tensor: """Computes the next state value (without gradient) to compute a target value.
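For reference, the target-net updaters above are driven by calling `step()` after each optimizer update. A hedged sketch with assumed `loss_module`, `loss_value` and `optimizer` names:

    # Hedged sketch: one soft update per optimization step.
    updater = SoftUpdate(loss_module, eps=0.99)  # target <- eps*target + (1-eps)*online
    loss_value.backward()   # loss_value assumed computed from loss_module beforehand
    optimizer.step()
    optimizer.zero_grad()
    updater.step()          # lerps a fraction (1 - eps) of online params into targets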
@@ -550,7 +550,7 @@ def decorated_module(*module_args_params): ) from err -def _reduce(tensor: torch.Tensor, reduction: str) -> Union[float, torch.Tensor]: +def _reduce(tensor: torch.Tensor, reduction: str) -> float | torch.Tensor: """Reduces a tensor given the reduction method.""" if reduction == "none": result = tensor @@ -632,8 +632,8 @@ def _maybe_get_or_select(td, key_or_keys, target_shape=None): def _maybe_add_or_extend_key( - tensor_keys: List[NestedKey], - key_or_list_of_keys: NestedKey | List[NestedKey], + tensor_keys: list[NestedKey], + key_or_list_of_keys: NestedKey | list[NestedKey], prefix: NestedKey = None, ): if prefix is not None: diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index a52b6e40d97..69bbcc91721 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -10,32 +10,25 @@ from contextlib import nullcontext from dataclasses import asdict, dataclass from functools import wraps -from typing import Callable, List, Union +from typing import Callable, Union import torch -from tensordict import is_tensor_collection, TensorDictBase +from tensordict import TensorDictBase, is_tensor_collection from tensordict.nn import ( + ProbabilisticTensorDictModule, + TensorDictModule, + TensorDictModuleBase, composite_lp_aggregate, dispatch, - ProbabilisticTensorDictModule, set_composite_lp_aggregate, set_skip_existing, - TensorDictModule, - TensorDictModuleBase, ) from tensordict.nn.probabilistic import interaction_type from tensordict.utils import NestedKey, unravel_key from torch import Tensor - from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp - -from torchrl.objectives.utils import ( - _maybe_get_or_select, - _vmap_func, - hold_out_net, - RANDOM_MODULE_LIST, -) +from torchrl.objectives.utils import (RANDOM_MODULE_LIST, _maybe_get_or_select, _vmap_func, hold_out_net) from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -153,7 +146,7 @@ def __post_init__(self): default_keys = _AcceptedKeys tensor_keys: _AcceptedKeys - value_network: Union[TensorDictModule, Callable] + value_network: TensorDictModule | Callable _vmap_randomness = None @property @@ -1038,8 +1031,8 @@ def forward( self, tensordict: TensorDictBase, *, - params: List[Tensor] | None = None, - target_params: List[Tensor] | None = None, + params: list[Tensor] | None = None, + target_params: list[Tensor] | None = None, ) -> TensorDictBase: r"""Computes the TD(:math:`\lambda`) advantage given the data in tensordict. @@ -1307,8 +1300,8 @@ def forward( self, tensordict: TensorDictBase, *, - params: List[Tensor] | None = None, - target_params: List[Tensor] | None = None, + params: list[Tensor] | None = None, + target_params: list[Tensor] | None = None, time_dim: int | None = None, ) -> TensorDictBase: """Computes the GAE given the data in tensordict. @@ -1646,8 +1639,8 @@ def forward( self, tensordict: TensorDictBase, *, - params: List[Tensor] | None = None, - target_params: List[Tensor] | None = None, + params: list[Tensor] | None = None, + target_params: list[Tensor] | None = None, time_dim: int | None = None, ) -> TensorDictBase: """Computes the V-Trace correction given the data in tensordict. 
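For context on the `forward` signatures rewritten above, a minimal sketch of the usual call; the value network and rollout tensordict are assumed names, and the output keys follow this module's defaults:

    from torchrl.objectives.value import GAE

    # Hedged sketch: writes "advantage" and "value_target" into the rollout.
    gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)
    rollout = gae(rollout)
    # Functional variant matching the keyword-only arguments above:
    # rollout = gae(rollout, params=params, target_params=target_params)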
diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 15e5d56d6bf..87bdf1c630d 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -5,10 +5,9 @@ from __future__ import annotations import math - import warnings from functools import wraps -from typing import Optional, Tuple, Union +from typing import Union import torch @@ -129,7 +128,7 @@ def generalized_advantage_estimate( terminated: torch.Tensor | None = None, *, time_dim: int = -2, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Generalized advantage estimate of a trajectory. Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION" @@ -271,8 +270,8 @@ def _fast_vec_gae( @_transpose_time def vec_generalized_advantage_estimate( - gamma: Union[float, torch.Tensor], - lmbda: Union[float, torch.Tensor], + gamma: float | torch.Tensor, + lmbda: float | torch.Tensor, state_value: torch.Tensor, next_state_value: torch.Tensor, reward: torch.Tensor, @@ -280,7 +279,7 @@ def vec_generalized_advantage_estimate( terminated: torch.Tensor | None = None, *, time_dim: int = -2, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Vectorized Generalized advantage estimate of a trajectory. Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION" @@ -382,7 +381,7 @@ def td0_advantage_estimate( reward: torch.Tensor, done: torch.Tensor, terminated: torch.Tensor | None = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """TD(0) advantage estimate of a trajectory. Also known as bootstrapped Temporal Difference or one-step return. @@ -422,7 +421,7 @@ def td0_return_estimate( terminated: torch.Tensor | None = None, *, done: torch.Tensor | None = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: # noqa: D417 """TD(0) discounted return estimate of a trajectory. @@ -645,7 +644,7 @@ def vec_td1_return_estimate( reward, done: torch.Tensor, terminated: torch.Tensor | None = None, - rolling_gamma: Optional[bool] = None, + rolling_gamma: bool | None = None, time_dim: int = -2, ): """Vectorized TD(1) return estimate. @@ -970,7 +969,7 @@ def td_lambda_advantage_estimate( def _fast_td_lambda_return_estimate( - gamma: Union[torch.Tensor, float], + gamma: torch.Tensor | float, lmbda: float, next_state_value: torch.Tensor, reward: torch.Tensor, @@ -1035,7 +1034,7 @@ def vec_td_lambda_return_estimate( reward, done, terminated: torch.Tensor | None = None, - rolling_gamma: Optional[bool] = None, + rolling_gamma: bool | None = None, *, time_dim: int = -2, ): @@ -1277,11 +1276,11 @@ def vtrace_advantage_estimate( reward: torch.Tensor, done: torch.Tensor, terminated: torch.Tensor | None = None, - rho_thresh: Union[float, torch.Tensor] = 1.0, - c_thresh: Union[float, torch.Tensor] = 1.0, + rho_thresh: float | torch.Tensor = 1.0, + c_thresh: float | torch.Tensor = 1.0, # not a kwarg because used directly time_dim: int = -2, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Computes V-Trace off-policy actor critic targets. 
Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 745e0feabb1..ae5124fb11c 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -7,11 +7,10 @@ import os from collections import defaultdict from pathlib import Path -from typing import Dict, Optional, Sequence, Union +from typing import Sequence import tensordict.utils import torch - from tensordict import MemoryMappedTensor from torch import Tensor @@ -35,7 +34,7 @@ def __init__(self, log_dir: str, *, video_format="pt", video_fps: int = 30): self.files = {} - def add_scalar(self, name: str, value: float, global_step: Optional[int] = None): + def add_scalar(self, name: str, value: float, global_step: int | None = None): if global_step is None: global_step = len(self.scalars[name]) value = float(value) @@ -50,7 +49,7 @@ def add_scalar(self, name: str, value: float, global_step: Optional[int] = None) fd.write(",".join([str(global_step), str(value)]) + "\n") fd.flush() - def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs): + def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs): """Writes a video to a file on disk. The video format can be one of @@ -106,7 +105,7 @@ def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'." ) - def add_text(self, tag, text, global_step: Optional[int] = None): + def add_text(self, tag, text, global_step: int | None = None): if global_step is None: global_step = self.videos_counter[tag] self.videos_counter[tag] += 1 @@ -161,7 +160,7 @@ def __init__( super().__init__(exp_name=exp_name, log_dir=log_dir) self._has_imported_moviepy = False - def _create_experiment(self) -> "CSVExperiment": + def _create_experiment(self) -> CSVExperiment: """Creates a CSV experiment.""" log_dir = str(os.path.join(self.log_dir, self.exp_name)) return CSVExperiment( @@ -205,7 +204,7 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non **kwargs, ) - def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 """Logs the hyperparameters of the experiment.
Args: diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 1181d6e2d0d..b2a68964571 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -7,15 +7,12 @@ import importlib.util import math from copy import copy -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, List, Sequence import numpy as np import torch - from tensordict import NonTensorData, TensorDictBase - from tensordict.utils import NestedKey - from torchrl._utils import _can_be_pickled from torchrl.data import TensorSpec from torchrl.data.tensor_specs import NonTensor, Unbounded @@ -108,11 +105,11 @@ def __init__( self, logger: Logger, tag: str, - in_keys: Optional[Sequence[NestedKey]] = None, + in_keys: Sequence[NestedKey] | None = None, skip: int | None = None, - center_crop: Optional[int] = None, + center_crop: int | None = None, make_grid: bool | None = None, - out_keys: Optional[Sequence[NestedKey]] = None, + out_keys: Sequence[NestedKey] | None = None, fps: int | None = None, **kwargs, ) -> None: @@ -239,7 +236,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return self._call(tensordict) - def dump(self, suffix: Optional[str] = None) -> None: + def dump(self, suffix: str | None = None) -> None: """Writes the video to the ``self.logger`` attribute. Calling ``dump`` when no image has been stored is a no-op. @@ -296,7 +293,7 @@ def __init__( out_file_base: str, skip_reset: bool = True, skip: int = 4, - in_keys: Optional[Sequence[str]] = None, + in_keys: Sequence[str] | None = None, ) -> None: if in_keys is None: in_keys = [] @@ -318,7 +315,7 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: self.td.append(_td) return next_tensordict - def dump(self, suffix: Optional[str] = None) -> None: + def dump(self, suffix: str | None = None) -> None: if suffix is None: tag = self.tag else: @@ -430,7 +427,7 @@ class PixelRenderTransform(Transform): def __init__( self, - out_keys: List[NestedKey] = None, + out_keys: list[NestedKey] = None, preproc: Callable[ [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor ] = None, @@ -544,7 +541,7 @@ def enabled(self) -> bool: """Whether the recorder is enabled.""" return self._enabled - def set_container(self, container: Union[Transform, EnvBase]) -> None: + def set_container(self, container: Transform | EnvBase) -> None: out = super().set_container(container) if isinstance(self.parent, EnvBase): # Start the env if needed diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 65be247cd33..cb76a1bcd26 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -8,24 +8,18 @@ import abc import pathlib import warnings -from collections import defaultdict, OrderedDict +from collections import OrderedDict, defaultdict from copy import deepcopy from textwrap import indent -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Sequence, Tuple import numpy as np import torch.nn -from tensordict import pad, TensorDictBase +from tensordict import TensorDictBase, pad from tensordict.nn import TensorDictModule from tensordict.utils import expand_right from torch import nn, optim - -from torchrl._utils import ( - _CKPT_BACKEND, - KeyDependentDefaultDict, - logger as torchrl_logger, - VERBOSE, -) +from torchrl._utils import (KeyDependentDefaultDict, VERBOSE, _CKPT_BACKEND, logger as torchrl_logger) from
torchrl.collectors.collectors import DataCollectorBase from torchrl.collectors.utils import split_trajectories from torchrl.data.replay_buffers import ( @@ -70,11 +64,11 @@ class TrainerHookBase: """An abstract hooking class for torchrl Trainer class.""" @abc.abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: raise NotImplementedError @abc.abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise NotImplementedError @abc.abstractmethod @@ -143,7 +137,7 @@ def __new__(cls, *args, **kwargs): # trackers cls._optim_count: int = 0 cls._collected_frames: int = 0 - cls._last_log: Dict[str, Any] = {} + cls._last_log: dict[str, Any] = {} cls._last_save: int = 0 cls.collected_frames = 0 cls._app_state = None @@ -156,16 +150,16 @@ def __init__( total_frames: int, frame_skip: int, optim_steps_per_batch: int, - loss_module: Union[LossModule, Callable[[TensorDictBase], TensorDictBase]], - optimizer: Optional[optim.Optimizer] = None, - logger: Optional[Logger] = None, + loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase], + optimizer: optim.Optimizer | None = None, + logger: Logger | None = None, clip_grad_norm: bool = True, clip_norm: float = None, progress_bar: bool = True, seed: int = None, save_trainer_interval: int = 10000, log_interval: int = 10000, - save_trainer_file: Optional[Union[str, pathlib.Path]] = None, + save_trainer_file: str | pathlib.Path | None = None, ) -> None: # objects @@ -248,7 +242,7 @@ def app_state(self): } return self._app_state - def state_dict(self) -> Dict: + def state_dict(self) -> dict: state = self._get_state() state_dict = OrderedDict( collector=self.collector.state_dict(), @@ -258,7 +252,7 @@ def state_dict(self) -> Dict: ) return state_dict - def load_state_dict(self, state_dict: Dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: model_state_dict = state_dict["loss_module"] collector_state_dict = state_dict["collector"] @@ -296,7 +290,7 @@ def save_trainer(self, force_save: bool = False) -> None: if _save and self.save_trainer_file: self._save_trainer() - def load_from_file(self, file: Union[str, pathlib.Path], **kwargs) -> Trainer: + def load_from_file(self, file: str | pathlib.Path, **kwargs) -> Trainer: """Loads a file and its state-dict in the trainer. Keyword arguments are passed to the :func:`~torch.load` function. 
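The `save_trainer`/`load_from_file` pair documented above composes into a simple checkpoint round-trip; the file path below is an assumed example:

    # Hedged sketch: persist and restore the trainer state dicts.
    trainer.save_trainer(force_save=True)       # writes to trainer.save_trainer_file
    trainer.load_from_file("trainer_ckpt.pt")   # extra kwargs are forwarded to torch.load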
@@ -617,10 +611,10 @@ def __init__(self, keys: Sequence[str]): def __call__(self, batch: TensorDictBase) -> TensorDictBase: return batch.select(*self.keys) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: pass def register(self, trainer, name="select_keys") -> None: @@ -664,11 +658,11 @@ class ReplayBufferTrainer(TrainerHookBase): def __init__( self, replay_buffer: TensorDictReplayBuffer, - batch_size: Optional[int] = None, + batch_size: int | None = None, memmap: bool = False, device: DEVICE_TYPING | None = None, flatten_tensordicts: bool = False, - max_dims: Optional[Sequence[int]] = None, + max_dims: Sequence[int] | None = None, ) -> None: self.replay_buffer = replay_buffer self.batch_size = batch_size @@ -704,7 +698,7 @@ def sample(self, batch: TensorDictBase) -> TensorDictBase: def update_priority(self, batch: TensorDictBase) -> None: self.replay_buffer.update_tensordict_priority(batch) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "replay_buffer": self.replay_buffer.state_dict(), } @@ -738,7 +732,7 @@ class OptimizerHook(TrainerHookBase): def __init__( self, optimizer: optim.Optimizer, - loss_components: Optional[Sequence[str]] = None, + loss_components: Sequence[str] | None = None, ): if loss_components is not None and not loss_components: raise ValueError( @@ -788,10 +782,10 @@ def __call__( return losses_td - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: pass def register(self, trainer, name="optimizer") -> None: @@ -838,7 +832,7 @@ def __init__( self, logname="r_training", log_pbar: bool = False, - reward_key: Union[str, tuple] = None, + reward_key: str | tuple = None, ): self.logname = logname self.log_pbar = log_pbar @@ -846,7 +840,7 @@ def __init__( reward_key = REWARD_KEY self.reward_key = reward_key - def __call__(self, batch: TensorDictBase) -> Dict: + def __call__(self, batch: TensorDictBase) -> dict: if ("collector", "mask") in batch.keys(True): return { self.logname: batch.get(self.reward_key)[ @@ -873,7 +867,7 @@ def __init__( self, logname="r_training", log_pbar: bool = False, - reward_key: Union[str, tuple] = None, + reward_key: str | tuple = None, ): warnings.warn( "The 'LogReward' class is deprecated and will be removed in v0.9. 
Please use 'LogScalar' instead.", @@ -971,7 +965,7 @@ def normalize_reward(self, tensordict: TensorDictBase) -> TensorDictBase: self._normalize_has_been_called = True return tensordict - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "_reward_stats": deepcopy(self._reward_stats), "scale": self.scale, @@ -979,7 +973,7 @@ def state_dict(self) -> Dict[str, Any]: "_update_has_been_called": self._update_has_been_called, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: for key, value in state_dict.items(): setattr(self, key, value) @@ -1126,10 +1120,10 @@ def __call__(self, batch: TensorDictBase) -> TensorDictBase: raise RuntimeError("Sampled invalid steps") return td - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: pass def register(self, trainer: Trainer, name: str = "batch_subsampler"): @@ -1194,9 +1188,9 @@ def __init__( policy_exploration: TensorDictModule, environment: EnvBase = None, exploration_type: ExplorationType = ExplorationType.RANDOM, - log_keys: Optional[List[Union[str, Tuple[str]]]] = None, - out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None, - suffix: Optional[str] = None, + log_keys: list[str | tuple[str]] | None = None, + out_keys: dict[str | tuple[str], str] | None = None, + suffix: str | None = None, log_pbar: bool = False, recorder: EnvBase = None, ) -> None: @@ -1223,7 +1217,7 @@ def __init__( self.log_pbar = log_pbar @torch.inference_mode() - def __call__(self, batch: TensorDictBase) -> Dict: + def __call__(self, batch: TensorDictBase) -> dict: out = None if self._count % self.record_interval == 0: with set_exploration_type(self.exploration_type): @@ -1259,13 +1253,13 @@ def __call__(self, batch: TensorDictBase) -> Dict: self.environment.close() return out - def state_dict(self) -> Dict: + def state_dict(self) -> dict: return { "_count": self._count, "recorder_state_dict": self.environment.state_dict(), } - def load_state_dict(self, state_dict: Dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: self._count = state_dict["_count"] self.environment.load_state_dict(state_dict["recorder_state_dict"]) @@ -1289,9 +1283,9 @@ def __init__( policy_exploration: TensorDictModule, environment: EnvBase = None, exploration_type: ExplorationType = ExplorationType.RANDOM, - log_keys: Optional[List[Union[str, Tuple[str]]]] = None, - out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None, - suffix: Optional[str] = None, + log_keys: list[str | tuple[str]] | None = None, + out_keys: dict[str | tuple[str], str] | None = None, + suffix: str | None = None, log_pbar: bool = False, recorder: EnvBase = None, ) -> None: @@ -1352,7 +1346,7 @@ def register(self, trainer: Trainer, name: str = "update_weights"): self, ) - def state_dict(self) -> Dict: + def state_dict(self) -> dict: return {} def load_state_dict(self, state_dict) -> None: @@ -1385,7 +1379,7 @@ def __init__(self, frame_skip: int, log_pbar: bool = False): self.frame_skip = frame_skip self.log_pbar = log_pbar - def __call__(self, batch: TensorDictBase) -> Dict: + def __call__(self, batch: TensorDictBase) -> dict: if ("collector", "mask") in batch.keys(True): current_frames = ( batch.get(("collector", "mask")).sum().item() * self.frame_skip @@ -1402,7 +1396,7 @@ def register(self, trainer: Trainer, name: str 
= "count_frames_log"): self, ) - def state_dict(self) -> Dict: + def state_dict(self) -> dict: return {"frame_count": self.frame_count} def load_state_dict(self, state_dict) -> None: @@ -1410,7 +1404,7 @@ def load_state_dict(self, state_dict) -> None: def _check_input_output_typehint( - func: Callable, input: Type | List[Type], output: Type + func: Callable, input: type | list[type], output: type ): # Placeholder for a function that checks the types input / output against expectations return diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index fc360114377..4c94177b880 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -743,7 +743,7 @@ def print_csv_files_in_folder(folder_path): csv_files.append(os.path.join(dirpath, file)) for csv_file in csv_files: output_str += f"File: {csv_file}\n" - with open(csv_file, "r") as f: + with open(csv_file) as f: for i, line in enumerate(f): if i == 10: break From 14dbdb21aec8c1d97f30d8a08586eed4e78bc627 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 28 Feb 2025 17:15:39 +0000 Subject: [PATCH 02/10] Update [ghstack-poisoned] --- .pre-commit-config.yaml | 8 ++ build_tools/setup_helpers/extension.py | 2 +- sota-implementations/dreamer/dreamer.py | 2 + sota-implementations/redq/utils.py | 7 +- test/_utils_internal.py | 12 +- test/mocking_classes.py | 11 +- test/smoke_test_deps.py | 2 - test/test_collector.py | 26 +++- test/test_cost.py | 1 - test/test_exploration.py | 2 +- test/test_libs.py | 2 +- test/test_loggers.py | 9 +- test/test_modules.py | 2 +- test/test_rb.py | 5 +- test/test_specs.py | 3 +- test/test_tensordictmodules.py | 2 - test/test_transforms.py | 19 +-- torchrl/__init__.py | 14 -- torchrl/_utils.py | 4 +- torchrl/collectors/__init__.py | 10 -- torchrl/collectors/collectors.py | 17 ++- torchrl/collectors/distributed/__init__.py | 6 - torchrl/collectors/distributed/generic.py | 7 +- torchrl/collectors/distributed/ray.py | 5 +- torchrl/collectors/distributed/rpc.py | 2 +- torchrl/collectors/distributed/sync.py | 7 +- torchrl/data/__init__.py | 99 --------------- torchrl/data/datasets/__init__.py | 10 -- torchrl/data/datasets/d4rl.py | 3 +- torchrl/data/datasets/openx.py | 5 +- torchrl/data/map/__init__.py | 5 - torchrl/data/map/tdstorage.py | 11 +- torchrl/data/map/tree.py | 12 +- torchrl/data/postprocs/__init__.py | 2 - torchrl/data/replay_buffers/__init__.py | 46 ------- torchrl/data/replay_buffers/checkpointers.py | 10 +- torchrl/data/replay_buffers/storages.py | 15 ++- torchrl/data/replay_buffers/writers.py | 2 +- torchrl/data/rlhf/__init__.py | 10 -- torchrl/data/rlhf/prompt.py | 3 +- torchrl/data/rlhf/utils.py | 3 +- torchrl/envs/__init__.py | 120 ------------------ torchrl/envs/batched_envs.py | 19 ++- torchrl/envs/custom/__init__.py | 5 - torchrl/envs/custom/llm.py | 1 - torchrl/envs/gym_like.py | 3 +- torchrl/envs/libs/__init__.py | 24 ---- torchrl/envs/libs/dm_control.py | 3 +- torchrl/envs/libs/gym.py | 7 +- torchrl/envs/libs/meltingpot.py | 5 +- torchrl/envs/libs/openspiel.py | 3 +- torchrl/envs/libs/pettingzoo.py | 19 +-- torchrl/envs/libs/robohive.py | 8 +- torchrl/envs/libs/unity_mlagents.py | 3 +- torchrl/envs/libs/vmas.py | 11 +- torchrl/envs/model_based/__init__.py | 3 - torchrl/envs/transforms/__init__.py | 68 ---------- torchrl/envs/transforms/rlhf.py | 7 +- torchrl/envs/transforms/transforms.py | 2 - torchrl/envs/transforms/vip.py | 3 +- torchrl/modules/__init__.py | 88 ------------- 
torchrl/modules/distributions/__init__.py | 11 +- torchrl/modules/distributions/continuous.py | 3 +- .../modules/distributions/truncated_normal.py | 2 +- torchrl/modules/models/__init__.py | 44 ------- torchrl/modules/models/exploration.py | 3 +- torchrl/modules/models/models.py | 7 +- torchrl/modules/models/multiagent.py | 2 +- torchrl/modules/models/utils.py | 2 +- torchrl/modules/planners/__init__.py | 4 - torchrl/modules/tensordict_module/__init__.py | 44 ------- torchrl/modules/tensordict_module/actors.py | 7 +- torchrl/modules/tensordict_module/common.py | 3 +- .../modules/tensordict_module/exploration.py | 5 +- .../tensordict_module/probabilistic.py | 1 - torchrl/modules/tensordict_module/rnn.py | 5 +- torchrl/modules/utils/__init__.py | 9 +- torchrl/objectives/__init__.py | 30 ----- torchrl/objectives/a2c.py | 18 ++- torchrl/objectives/common.py | 5 +- torchrl/objectives/cql.py | 8 +- torchrl/objectives/crossq.py | 5 +- torchrl/objectives/ddpg.py | 7 +- torchrl/objectives/decision_transformer.py | 3 +- torchrl/objectives/deprecated.py | 12 +- torchrl/objectives/dqn.py | 5 +- torchrl/objectives/dreamer.py | 3 +- torchrl/objectives/iql.py | 5 +- torchrl/objectives/multiagent/__init__.py | 2 - torchrl/objectives/multiagent/qmixer.py | 7 +- torchrl/objectives/ppo.py | 18 ++- torchrl/objectives/redq.py | 7 +- torchrl/objectives/sac.py | 11 +- torchrl/objectives/td3.py | 7 +- torchrl/objectives/td3_bc.py | 5 +- torchrl/objectives/utils.py | 4 +- torchrl/objectives/value/__init__.py | 12 -- torchrl/objectives/value/advantages.py | 19 ++- torchrl/objectives/value/functional.py | 1 - torchrl/record/__init__.py | 3 - torchrl/record/loggers/__init__.py | 9 -- torchrl/record/recorder.py | 2 +- torchrl/trainers/__init__.py | 18 --- torchrl/trainers/helpers/__init__.py | 18 --- torchrl/trainers/trainers.py | 12 +- tutorials/sphinx-tutorials/coding_dqn.py | 1 - 106 files changed, 328 insertions(+), 906 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f31b7c41f7b..37adaef7979 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,3 +41,11 @@ repos: hooks: - id: pyupgrade args: [--py38-plus] + + - repo: local + hooks: + - id: autoflake + name: autoflake + entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports + language: system + types: [python] diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 973238763fb..82df53d3af0 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -8,7 +8,7 @@ import platform import subprocess from pathlib import Path -from subprocess import CalledProcessError, STDOUT, check_output +from subprocess import CalledProcessError, check_output, STDOUT import torch from setuptools import Extension diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 4534f93a4f5..a39c8904916 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -11,6 +11,7 @@ import torch import torch.cuda import tqdm + from dreamer_utils import ( _default_device, dump_video, @@ -20,6 +21,7 @@ make_environments, make_replay_buffer, ) + # mixed precision training from torch.amp import GradScaler from torch.nn.utils import clip_grad_norm_ diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 2fc55cd853f..0528e2b809e 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -17,7 
+17,8 @@ ) from torch import distributions as d, nn, optim from torch.optim.lr_scheduler import CosineAnnealingLR -from torchrl._utils import VERBOSE, logger as torchrl_logger + +from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.collectors.collectors import DataCollectorBase from torchrl.data import ( LazyMemmapStorage, @@ -35,10 +36,12 @@ Compose, DMControlEnv, DoubleToFloat, + env_creator, EnvBase, EnvCreator, FlattenObservation, GrayScale, + gSDENoise, GymEnv, InitTracker, NoopResetEnv, @@ -50,8 +53,6 @@ ToTensorImage, TransformedEnv, VecNorm, - env_creator, - gSDENoise, ) from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 89762567529..05fdada16d2 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -17,15 +17,21 @@ import pytest import torch import torch.cuda -from tensordict import NestedKey, TensorDict, TensorDictBase, tensorclass +from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase from tensordict.nn import TensorDictModuleBase from torch import nn, vmap -from torchrl._utils import (RL_WARNINGS, implement_for, logger as torchrl_logger, seed_generator) + +from torchrl._utils import ( + implement_for, + logger as torchrl_logger, + RL_WARNINGS, + seed_generator, +) from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import MultiThreadedEnv, ObservationNorm from torchrl.envs.batched_envs import ParallelEnv, SerialEnv from torchrl.envs.libs.envpool import _has_envpool -from torchrl.envs.libs.gym import GymEnv, _has_gym, gym_backend +from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv from torchrl.envs.transforms import ( Compose, RewardClipping, diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 2b77b45df40..5f03e773591 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -10,9 +10,10 @@ import numpy as np import torch import torch.nn as nn -from tensordict import TensorDict, TensorDictBase, tensorclass +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import TensorDictModuleBase -from tensordict.utils import NestedKey, expand_right +from tensordict.utils import expand_right, NestedKey + from torchrl.data import ( Binary, Bounded, @@ -28,7 +29,11 @@ from torchrl.envs import Transform from torchrl.envs.common import EnvBase from torchrl.envs.model_based.common import ModelBasedEnvBase -from torchrl.envs.utils import (MarlGroupMapType, _terminated_or_truncated, check_marl_grouping) +from torchrl.envs.utils import ( + _terminated_or_truncated, + check_marl_grouping, + MarlGroupMapType, +) spec_dict = { "bounded": Bounded, diff --git a/test/smoke_test_deps.py b/test/smoke_test_deps.py index a803707408c..d6133ed1a64 100644 --- a/test/smoke_test_deps.py +++ b/test/smoke_test_deps.py @@ -9,8 +9,6 @@ import pytest -from torchrl.envs.libs.gym import gym_backend - def test_dm_control(): import dm_control # noqa: F401 diff --git a/test/test_collector.py b/test/test_collector.py index 523f851b37d..0ee85798f7d 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -17,7 +17,13 @@ import pytest import torch from packaging import version -from tensordict import (LazyStackedTensorDict, NonTensorData, TensorDict, TensorDictBase, assert_allclose_td) +from tensordict import ( + assert_allclose_td, + LazyStackedTensorDict, + NonTensorData, + TensorDict, + TensorDictBase, +) from tensordict.nn import ( CudaGraphModule, 
TensorDictModule, @@ -25,6 +31,7 @@ TensorDictSequential, ) from torch import nn + from torchrl._utils import ( _make_ordinal_device, _replace_last, @@ -32,8 +39,12 @@ prod, seed_generator, ) -from torchrl.collectors import SyncDataCollector, aSyncDataCollector -from torchrl.collectors.collectors import (MultiSyncDataCollector, MultiaSyncDataCollector, _Interruptor) +from torchrl.collectors import aSyncDataCollector, SyncDataCollector +from torchrl.collectors.collectors import ( + _Interruptor, + MultiaSyncDataCollector, + MultiSyncDataCollector, +) from torchrl.collectors.utils import split_trajectories from torchrl.data import ( Composite, @@ -54,9 +65,14 @@ StepCounter, Transform, ) -from torchrl.envs.libs.gym import GymEnv, _has_gym, gym_backend, set_gym_backend +from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, set_gym_backend from torchrl.envs.transforms import TransformedEnv, VecNorm -from torchrl.envs.utils import (PARTIAL_MISSING_ERR, RandomPolicy, _aggregate_end_of_traj, check_env_specs) +from torchrl.envs.utils import ( + _aggregate_end_of_traj, + check_env_specs, + PARTIAL_MISSING_ERR, + RandomPolicy, +) from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule if os.getenv("PYTORCH_TEST_FBCODE"): diff --git a/test/test_cost.py b/test/test_cost.py index 33f7cfb0e80..eda6596e8e4 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5859,7 +5859,6 @@ def test_crossq_tensordict_keys(self, td_est): actor = self._create_mock_actor() qvalue = self._create_mock_qvalue() - value = None loss_fn = CrossQLoss( actor_network=actor, diff --git a/test/test_exploration.py b/test/test_exploration.py index 10a4938c6fd..554ebdc4aca 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -381,7 +381,7 @@ def test_nested( action_spec = env.action_spec d_act = action_spec.shape[-1] - net = nn.LazyLinear(d_act).to(device) + nn.LazyLinear(d_act).to(device) policy = TensorDictModule( CountingEnvCountModule(action_spec=action_spec), in_keys=[("data", "states") if nested_obs_action else "observation"], diff --git a/test/test_libs.py b/test/test_libs.py index 57358b732a4..97f4c0f9fe7 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -195,7 +195,7 @@ def get_gym_pixel_wrapper(): PixelObservationWrapper = gym_backend( "wrappers.pixel_observation" ).PixelObservationWrapper - except Exception as err: + except Exception: from torchrl.envs.libs.utils import ( GymPixelObservationWrapper as PixelObservationWrapper, ) diff --git a/test/test_loggers.py b/test/test_loggers.py index cf438feb9f1..87003250d8d 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -14,11 +14,12 @@ import pytest import torch from tensordict import MemoryMappedTensor -from torchrl.envs import GymEnv, ParallelEnv, check_env_specs + +from torchrl.envs import check_env_specs, GymEnv, ParallelEnv from torchrl.record.loggers.csv import CSVLogger -from torchrl.record.loggers.mlflow import MLFlowLogger, _has_mlflow, _has_tv -from torchrl.record.loggers.tensorboard import TensorboardLogger, _has_tb -from torchrl.record.loggers.wandb import WandbLogger, _has_wandb +from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger +from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger +from torchrl.record.loggers.wandb import _has_wandb, WandbLogger from torchrl.record.recorder import PixelRenderTransform, VideoRecorder if _has_tv: diff --git a/test/test_modules.py b/test/test_modules.py index f661fa6199d..63dda533a90 100644 --- 
a/test/test_modules.py +++ b/test/test_modules.py @@ -939,7 +939,7 @@ def test_multiagent_mlp_tdparams( else: return mlp = nn.Sequential(mlp) - mlp_device = mlp.to(device) + mlp.to(device) param_set = set(mlp.parameters()) for p in mlp[0].params.values(True, True): assert p in param_set diff --git a/test/test_rb.py b/test/test_rb.py index 70f254a47aa..81b4ab3759f 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1626,7 +1626,6 @@ def test_extend(self, rbtype, storage, size, prefetch): rb.extend(data) length = len(rb) for d in data[-length:]: - found_similar = False for b in rb._storage: if isinstance(b, TensorDictBase): keys = set(d.keys()).intersection(b.keys()) @@ -1657,7 +1656,6 @@ def test_sample(self, rbtype, storage, size, prefetch): new_data = new_data[0] for d in new_data: - found_similar = False for b in data: if isinstance(b, TensorDictBase): keys = set(d.keys()).intersection(b.keys()) @@ -2930,7 +2928,6 @@ def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span): index = rb.extend(data) rb.update_priority(index, data["priority"]) found_traj_0 = False - found_traj_4_truncated_left = False found_traj_4_truncated_right = False for i, s in enumerate(rb): t = s["traj"].unique().tolist() @@ -2942,7 +2939,7 @@ def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span): if s["step_count"][0] > 10: found_traj_4_truncated_right = True if s["step_count"][0] == 0: - found_traj_4_truncated_left = True + pass if i == 1000: break assert not rb._sampler.span[0] diff --git a/test/test_specs.py b/test/test_specs.py index 142aa128c6e..22eb23e82ea 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -10,6 +10,7 @@ import numpy as np import pytest import torch + import torchrl.data.tensor_specs from scipy.stats import chisquare from tensordict import ( @@ -22,6 +23,7 @@ from tensordict.utils import _unravel_key_to_tuple, set_capture_non_tensor_stack from torchrl._utils import _make_ordinal_device from torchrl.data.tensor_specs import ( + _keys_to_empty_composite_spec, Binary, BinaryDiscreteTensorSpec, Bounded, @@ -47,7 +49,6 @@ UnboundedContinuousTensorSpec, UnboundedDiscrete, UnboundedDiscreteTensorSpec, - _keys_to_empty_composite_spec, ) from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 7ba45fb8587..9400f111ece 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -1639,8 +1639,6 @@ def test_batched_actor_exceptions(self): with pytest.raises(ValueError, match="Only a single init_key can be passed"): MultiStepActorWrapper(actor_base, n_steps=time_steps, init_key=["init_key"]) - n_obs = 1 - n_action = 1 batch = 2 # The second env has frequent resets, the first none diff --git a/test/test_transforms.py b/test/test_transforms.py index 07b103ef996..ffa81fde061 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -442,8 +442,6 @@ def test_transform_inverse(self): class TestClipTransform(TransformBase): @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, rbclass): - device = "cpu" - batch = [20] torch.manual_seed(0) rb = rbclass(storage=LazyTensorStorage(20)) @@ -1271,7 +1269,7 @@ def test_catframes_reset(self, device): buffer = getattr(cat_frames, f"_cat_buffers_{key1}") tdc = td.clone() - passed_back_td = cat_frames._reset(tdc, tdc) + cat_frames._reset(tdc, tdc) # assert tdc is passed_back_td # assert (buffer == 0).all() @@ -4787,7 +4785,7 @@ def 
make_env(): ) return env - env = SerialEnv(2, make_env) + SerialEnv(2, make_env) def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): @@ -8544,7 +8542,6 @@ def test_transform_model(self): keys = [key1, key2] dim = -2 d = 4 - N = 3 batch_size = (5,) extra_d = (3,) * (-dim - 1) device = "cpu" @@ -8570,7 +8567,6 @@ def test_transform_rb(self, rbclass): keys = [key1, key2] dim = -2 d = 4 - N = 3 batch_size = (5,) extra_d = (3,) * (-dim - 1) device = "cpu" @@ -8594,7 +8590,6 @@ def test_transform_rb(self, rbclass): def test_tmp_reset(self, device): key1 = "first key" key2 = "second key" - N = 4 keys = [key1, key2] key1_tensor = torch.randn(1, 1, 3, 3, device=device) key2_tensor = torch.randn(1, 1, 3, 3, device=device) @@ -8606,7 +8601,7 @@ def test_tmp_reset(self, device): buffer = getattr(t, f"_maxpool_buffer_{key1}") tdc = td.clone() - passed_back_td = t._reset(tdc, tdc.empty()) + t._reset(tdc, tdc.empty()) # assert tdc is passed_back_td assert (buffer != 0).any() @@ -11771,7 +11766,7 @@ def test_transform_env(self): def test_transform_model(self): t = Compose(DeviceCastTransform("cpu:1", "cpu:0")) - m = nn.Sequential(t) + nn.Sequential(t) assert t(TensorDict(device="cpu:0")).device == torch.device("cpu:1") @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) @@ -11892,7 +11887,7 @@ def test_transform_env(self): # check error with pytest.raises(ValueError, match="Only tailing dims with negative"): - t = PermuteTransform((-1, -10)) + PermuteTransform((-1, -10)) def test_transform_model(self): batch = [2] @@ -12217,7 +12212,7 @@ def test_transform_env(self): RuntimeError, match="BurnInTransform can only be appended to a ReplayBuffer.", ): - rollout = env.rollout(3) + env.rollout(3) @pytest.mark.parametrize("module", ["gru", "lstm"]) @pytest.mark.parametrize("batch_size", [2, 4]) @@ -12545,7 +12540,7 @@ def test_trans_serial_env_check(self): with pytest.raises( RuntimeError, match="The environment passed to SerialEnv has empty specs" ): - env = TransformedEnv(SerialEnv(2, self.DummyEnv), RemoveEmptySpecs()) + TransformedEnv(SerialEnv(2, self.DummyEnv), RemoveEmptySpecs()) def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): with pytest.raises( diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 42819ef1d8c..81e97018596 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -46,20 +46,6 @@ ) -import torchrl.collectors -import torchrl.data -import torchrl.envs -import torchrl.modules -import torchrl.objectives -import torchrl.trainers -from torchrl._utils import ( - auto_unwrap_transformed_env, - compile_with_warmup, - implement_for, - set_auto_unwrap_transformed_env, - timeit, -) - # Filter warnings in subprocesses: True by default given the multiple optional # deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`. 
filter_warnings_subprocess = True diff --git a/torchrl/_utils.py b/torchrl/_utils.py index f614b7f04a0..3258cb953eb 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -21,14 +21,14 @@ from distutils.util import strtobool from functools import wraps from importlib import import_module -from typing import Any, Callable, Tuple, TypeVar, cast +from typing import Any, Callable, cast, TypeVar import numpy as np import torch from packaging.version import parse from tensordict import unravel_key from tensordict.utils import NestedKey -from torch import Tensor, multiprocessing as mp +from torch import multiprocessing as mp, Tensor try: from torch.compiler import is_compiling diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index d69d8c9e50c..7bec24cb17b 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -2,13 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from torchrl.envs.utils import RandomPolicy - -from .collectors import ( - aSyncDataCollector, - DataCollectorBase, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index d18e2d9a25b..67dff40e9de 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -15,7 +15,7 @@ import time import typing import warnings -from collections import OrderedDict, defaultdict +from collections import defaultdict, OrderedDict from copy import deepcopy from multiprocessing import connection, queues from multiprocessing.managers import SyncManager @@ -38,31 +38,32 @@ from torch import multiprocessing as mp from torch.nn import Parameter from torch.utils.data import IterableDataset + from torchrl._utils import ( - RL_WARNINGS, - VERBOSE, - _ProcessNoWarn, _check_for_faulty_process, _ends_with, _make_ordinal_device, + _ProcessNoWarn, _replace_last, accept_remote_rref_udf_invocation, compile_with_warmup, logger as torchrl_logger, prod, + RL_WARNINGS, + VERBOSE, ) from torchrl.collectors.utils import split_trajectories from torchrl.data import ReplayBuffer from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING -from torchrl.envs.common import EnvBase, _do_nothing +from torchrl.envs.common import _do_nothing, EnvBase from torchrl.envs.env_creator import EnvCreator from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( - ExplorationType, - RandomPolicy, _aggregate_end_of_traj, _make_compatible_policy, + ExplorationType, + RandomPolicy, set_exploration_type, ) @@ -120,8 +121,6 @@ class _InterruptorManager(SyncManager): between processes. """ - pass - _InterruptorManager.register("_Interruptor", _Interruptor) diff --git a/torchrl/collectors/distributed/__init__.py b/torchrl/collectors/distributed/__init__.py index 97932619a60..7bec24cb17b 100644 --- a/torchrl/collectors/distributed/__init__.py +++ b/torchrl/collectors/distributed/__init__.py @@ -2,9 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
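Note: the unused bindings dropped in the test hunks above (e.g. `found_similar = False`, or binding `net = nn.LazyLinear(...)` when only the side effect matters) and the `pass` statements removed under docstrings, here in `_InterruptorManager` and in several env methods later in the series, are consistent with the autoflake hook added earlier in this patch; autoflake also strips redundant `pass` statements by default. A minimal sketch of both patterns (the names below are illustrative, not from the patch):

    import torch.nn as nn

    class _ManagerLike:
        """A docstring alone is a valid body, so a trailing `pass` is dead code."""

    def _probe(d_act: int) -> None:
        # Keep the call for its side effect; drop the unused binding.
        nn.LazyLinear(d_act)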
- -from .generic import DEFAULT_SLURM_CONF, DistributedDataCollector -from .ray import RayCollector -from .rpc import RPCDataCollector -from .sync import DistributedSyncDataCollector -from .utils import submitit_delayed_launcher diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index df0edc5d5d3..5ec55e23a16 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -11,16 +11,17 @@ import warnings from copy import copy, deepcopy from datetime import timedelta -from typing import Callable, List, OrderedDict +from typing import Callable, OrderedDict import torch.cuda from tensordict import TensorDict from torch import nn -from torchrl._utils import VERBOSE, _ProcessNoWarn, logger as torchrl_logger + +from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( - DEFAULT_EXPLORATION_TYPE, DataCollectorBase, + DEFAULT_EXPLORATION_TYPE, MultiSyncDataCollector, SyncDataCollector, ) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 46947f51d5d..1716609026b 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -6,16 +6,17 @@ from __future__ import annotations import warnings -from typing import Callable, Iterator, List, OrderedDict +from typing import Callable, Iterator, OrderedDict import torch import torch.nn as nn from tensordict import TensorDict, TensorDictBase + from torchrl._utils import logger as torchrl_logger from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( - DEFAULT_EXPLORATION_TYPE, DataCollectorBase, + DEFAULT_EXPLORATION_TYPE, MultiSyncDataCollector, SyncDataCollector, ) diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 2575469a7bb..ee73cfdf4e7 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -12,7 +12,7 @@ import time import warnings from copy import copy, deepcopy -from typing import Callable, List, OrderedDict +from typing import Callable, OrderedDict from torchrl._utils import logger as torchrl_logger from torchrl.collectors.distributed import DEFAULT_SLURM_CONF diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 17c54279e4b..0a2215e0abe 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -11,16 +11,17 @@ import warnings from copy import copy, deepcopy from datetime import timedelta -from typing import Callable, List, OrderedDict +from typing import Callable, OrderedDict import torch.cuda from tensordict import TensorDict from torch import nn -from torchrl._utils import VERBOSE, _ProcessNoWarn, logger as torchrl_logger + +from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( - DEFAULT_EXPLORATION_TYPE, DataCollectorBase, + DEFAULT_EXPLORATION_TYPE, MultiSyncDataCollector, SyncDataCollector, ) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 33353b93b5f..7bec24cb17b 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -2,102 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
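Note: this patch empties the eager re-export lists of many package `__init__.py` files (collectors, data, envs, modules, ...). Code that imported names from the package root will only keep working if the names are reintroduced elsewhere in the stack; one common way to preserve the public surface while deferring the imports is a PEP 562 module-level `__getattr__`. The sketch below is purely an assumption about the intent, not code from the patch, and the mapping entries are hypothetical:

    # Hypothetical lazy re-export shim for a package __init__.py (PEP 562).
    import importlib

    _LAZY = {"ReplayBuffer": "torchrl.data.replay_buffers.replay_buffers"}

    def __getattr__(name):
        try:
            module = importlib.import_module(_LAZY[name])
        except KeyError:
            raise AttributeError(name) from None
        return getattr(module, name)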
- -from .map import ( - BinaryToDecimal, - HashToInt, - MCTSForest, - QueryModule, - RandomProjectionHash, - SipHash, - TensorDictMap, - TensorMap, - Tree, -) -from .postprocs import MultiStep -from .replay_buffers import ( - Flat2TED, - FlatStorageCheckpointer, - H5Combine, - H5Split, - H5StorageCheckpointer, - ImmutableDatasetWriter, - LazyMemmapStorage, - LazyStackStorage, - LazyTensorStorage, - ListStorage, - ListStorageCheckpointer, - Nested2TED, - NestedStorageCheckpointer, - PrioritizedReplayBuffer, - PrioritizedSampler, - PrioritizedSliceSampler, - RandomSampler, - RemoteTensorDictReplayBuffer, - ReplayBuffer, - ReplayBufferEnsemble, - RoundRobinWriter, - SamplerEnsemble, - SamplerWithoutReplacement, - SliceSampler, - SliceSamplerWithoutReplacement, - Storage, - StorageCheckpointerBase, - StorageEnsemble, - StorageEnsembleCheckpointer, - TED2Flat, - TED2Nested, - TensorDictMaxValueWriter, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, - TensorDictRoundRobinWriter, - TensorStorage, - TensorStorageCheckpointer, - Writer, - WriterEnsemble, -) -from .rlhf import ( - AdaptiveKLController, - ConstantKLController, - create_infinite_iterator, - get_dataloader, - PairwiseDataset, - PromptData, - PromptTensorDictTokenizer, - RewardData, - RolloutFromModel, - TensorDictTokenizer, - TokenizedDatasetLoader, -) -from .tensor_specs import ( - Binary, - BinaryDiscreteTensorSpec, - Bounded, - BoundedContinuous, - BoundedTensorSpec, - Categorical, - Choice, - Composite, - CompositeSpec, - DEVICE_TYPING, - DiscreteTensorSpec, - LazyStackedCompositeSpec, - LazyStackedTensorSpec, - MultiCategorical, - MultiDiscreteTensorSpec, - MultiOneHot, - MultiOneHotDiscreteTensorSpec, - NonTensor, - NonTensorSpec, - OneHot, - OneHotDiscreteTensorSpec, - Stacked, - StackedComposite, - TensorSpec, - Unbounded, - UnboundedContinuous, - UnboundedContinuousTensorSpec, - UnboundedDiscrete, - UnboundedDiscreteTensorSpec, -) -from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index d099a3a1be5..7bec24cb17b 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -2,13 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -from .atari_dqn import AtariDQNExperienceReplay -from .common import BaseDatasetExperienceReplay -from .d4rl import D4RLExperienceReplay -from .gen_dgrl import GenDGRLExperienceReplay -from .minari_data import MinariExperienceReplay -from .openml import OpenMLExperienceReplay -from .openx import OpenXExperienceReplay -from .roboset import RobosetExperienceReplay -from .vd4rl import VD4RLExperienceReplay diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index e8b47619bce..5ad01ebfb59 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -15,7 +15,8 @@ import numpy as np import torch -from tensordict import PersistentTensorDict, TensorDict, make_tensordict +from tensordict import make_tensordict, PersistentTensorDict, TensorDict + from torchrl._utils import logger as torchrl_logger from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.common import BaseDatasetExperienceReplay diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index cf81f44cf56..607c68d84f9 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -14,8 +14,9 @@ from typing import Any, Callable import torch -from tensordict import NonTensorData, TensorDict, make_tensordict, pad +from tensordict import make_tensordict, NonTensorData, pad, TensorDict from tensordict.utils import _is_non_tensor + from torchrl.data.datasets.common import BaseDatasetExperienceReplay from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer @@ -24,7 +25,7 @@ SliceSampler, SliceSamplerWithoutReplacement, ) -from torchrl.data.replay_buffers.storages import Storage, TensorStorage, _collate_id +from torchrl.data.replay_buffers.storages import _collate_id, Storage, TensorStorage from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer _has_datasets = importlib.util.find_spec("datasets", None) is not None diff --git a/torchrl/data/map/__init__.py b/torchrl/data/map/__init__.py index c9bc25477c2..7bec24cb17b 100644 --- a/torchrl/data/map/__init__.py +++ b/torchrl/data/map/__init__.py @@ -2,8 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
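Note: the recurring one-line import reorderings in this series are sorting, not semantic changes: names in an import list are ordered case-insensitively on the original imported name, aliased imports (`logger as torchrl_logger`) sort under the original name rather than the alias, and underscore-prefixed names come first (see the `storages.py` hunk below). An example taken from the `d4rl.py` hunk above:

    # Before: from tensordict import PersistentTensorDict, TensorDict, make_tensordict
    # After, sorted case-insensitively on the imported name:
    from tensordict import make_tensordict, PersistentTensorDict, TensorDict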
- -from .hash import BinaryToDecimal, RandomProjectionHash, SipHash -from .query import HashToInt, QueryModule -from .tdstorage import TensorDictMap, TensorMap -from .tree import MCTSForest, Tree diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index f510f1d5dbe..f1464308144 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -7,14 +7,19 @@ import abc import functools from abc import abstractmethod -from typing import Any, Callable, Dict, Generic, List, TypeVar +from typing import Any, Callable, Generic, TypeVar import torch -from tensordict import NestedKey, TensorDictBase, is_tensor_collection +from tensordict import is_tensor_collection, NestedKey, TensorDictBase from tensordict.nn.common import TensorDictModuleBase + from torchrl.data.map.hash import RandomProjectionHash, SipHash from torchrl.data.map.query import QueryModule -from torchrl.data.replay_buffers.storages import (LazyTensorStorage, TensorStorage, _get_default_collate) +from torchrl.data.replay_buffers.storages import ( + _get_default_collate, + LazyTensorStorage, + TensorStorage, +) K = TypeVar("K") V = TypeVar("V") diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index ca8fc683d7f..dfb87223435 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -6,10 +6,18 @@ import weakref from collections import deque -from typing import Any, Callable, Dict, Literal, Tuple +from typing import Any, Callable, Literal import torch -from tensordict import (NestedKey, TensorClass, TensorDict, TensorDictBase, merge_tensordicts, unravel_key) +from tensordict import ( + merge_tensordicts, + NestedKey, + TensorClass, + TensorDict, + TensorDictBase, + unravel_key, +) + from torchrl.data.map.tdstorage import TensorDictMap from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree from torchrl.data.replay_buffers.storages import ListStorage diff --git a/torchrl/data/postprocs/__init__.py b/torchrl/data/postprocs/__init__.py index 707740f6946..7bec24cb17b 100644 --- a/torchrl/data/postprocs/__init__.py +++ b/torchrl/data/postprocs/__init__.py @@ -2,5 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .postprocs import MultiStep diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 4f230f30701..7bec24cb17b 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -2,49 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -from .checkpointers import ( - FlatStorageCheckpointer, - H5StorageCheckpointer, - ListStorageCheckpointer, - NestedStorageCheckpointer, - StorageCheckpointerBase, - StorageEnsembleCheckpointer, - TensorStorageCheckpointer, -) -from .replay_buffers import ( - PrioritizedReplayBuffer, - RemoteTensorDictReplayBuffer, - ReplayBuffer, - ReplayBufferEnsemble, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) -from .samplers import ( - PrioritizedSampler, - PrioritizedSliceSampler, - RandomSampler, - Sampler, - SamplerEnsemble, - SamplerWithoutReplacement, - SliceSampler, - SliceSamplerWithoutReplacement, -) -from .storages import ( - LazyMemmapStorage, - LazyStackStorage, - LazyTensorStorage, - ListStorage, - Storage, - StorageEnsemble, - TensorStorage, -) -from .utils import Flat2TED, H5Combine, H5Split, Nested2TED, TED2Flat, TED2Nested -from .writers import ( - ImmutableDatasetWriter, - RoundRobinWriter, - TensorDictMaxValueWriter, - TensorDictRoundRobinWriter, - Writer, - WriterEnsemble, -) diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 4e06756efc2..b545fd92227 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -11,17 +11,23 @@ import numpy as np import torch -from tensordict import (NonTensorData, PersistentTensorDict, TensorDict, is_tensor_collection) +from tensordict import ( + is_tensor_collection, + NonTensorData, + PersistentTensorDict, + TensorDict, +) from tensordict.memmap import MemoryMappedTensor from tensordict.utils import _STRDTYPE2DTYPE + from torchrl.data.replay_buffers.utils import ( + _save_pytree, Flat2TED, H5Combine, H5Split, Nested2TED, TED2Flat, TED2Nested, - _save_pytree, ) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index bdbeb983c2c..3515b8eb8e4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -17,12 +17,18 @@ import numpy as np import tensordict import torch -from tensordict import (LazyStackedTensorDict, TensorDict, TensorDictBase, is_tensor_collection) +from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDict, + TensorDictBase, +) from tensordict.base import _NESTED_TENSORS_AS_LISTS from tensordict.memmap import MemoryMappedTensor from tensordict.utils import _zip_strict from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( ListStorageCheckpointer, @@ -30,7 +36,12 @@ StorageEnsembleCheckpointer, TensorStorageCheckpointer, ) -from torchrl.data.replay_buffers.utils import (INT_CLASSES, _init_pytree, _is_int, tree_iter) +from torchrl.data.replay_buffers.utils import ( + _init_pytree, + _is_int, + INT_CLASSES, + tree_iter, +) class Storage: diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index adb82d54b2e..f7fd2a5eef2 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -15,7 +15,7 @@ import numpy as np import torch -from tensordict import MemoryMappedTensor, TensorDictBase, is_tensor_collection +from tensordict import is_tensor_collection, MemoryMappedTensor, TensorDictBase from tensordict.utils import _STRDTYPE2DTYPE, expand_as_right, is_tensorclass from torch import multiprocessing as mp diff --git 
a/torchrl/data/rlhf/__init__.py b/torchrl/data/rlhf/__init__.py index f0db092f2d1..7bec24cb17b 100644 --- a/torchrl/data/rlhf/__init__.py +++ b/torchrl/data/rlhf/__init__.py @@ -2,13 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .dataset import ( - create_infinite_iterator, - get_dataloader, - TensorDictTokenizer, - TokenizedDatasetLoader, -) -from .prompt import PromptData, PromptTensorDictTokenizer -from .reward import PairwiseDataset, RewardData -from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel diff --git a/torchrl/data/rlhf/prompt.py b/torchrl/data/rlhf/prompt.py index 3b9755db69a..8d6cfe54066 100644 --- a/torchrl/data/rlhf/prompt.py +++ b/torchrl/data/rlhf/prompt.py @@ -5,7 +5,8 @@ from __future__ import annotations import torch -from tensordict import TensorDict, tensorclass +from tensordict import tensorclass, TensorDict + from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader DEFAULT_DATASET = "CarperAI/openai_summarize_tldr" diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 01d18ecec11..a816c984062 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -11,8 +11,9 @@ import numpy as np import torch from tensordict import TensorDict -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn import functional as F + from torchrl.data.rlhf.prompt import PromptData _has_transformers = importlib.util.find_spec("transformers") is not None diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index d9753eafc08..7bec24cb17b 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -2,123 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -from .batched_envs import ParallelEnv, SerialEnv -from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv -from .env_creator import env_creator, EnvCreator, get_env_metadata -from .gym_like import default_info_dict_reader, GymLikeEnv -from .libs import ( - BraxEnv, - BraxWrapper, - DMControlEnv, - DMControlWrapper, - gym_backend, - GymEnv, - GymWrapper, - HabitatEnv, - IsaacGymEnv, - IsaacGymWrapper, - JumanjiEnv, - JumanjiWrapper, - MeltingpotEnv, - MeltingpotWrapper, - MOGymEnv, - MOGymWrapper, - MultiThreadedEnv, - MultiThreadedEnvWrapper, - OpenMLEnv, - OpenSpielEnv, - OpenSpielWrapper, - PettingZooEnv, - PettingZooWrapper, - register_gym_spec_conversion, - RoboHiveEnv, - set_gym_backend, - SMACv2Env, - SMACv2Wrapper, - UnityMLAgentsEnv, - UnityMLAgentsWrapper, - VmasEnv, - VmasWrapper, -) -from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase -from .transforms import ( - ActionDiscretizer, - ActionMask, - AutoResetEnv, - AutoResetTransform, - BatchSizeTransform, - BinarizeReward, - BurnInTransform, - CatFrames, - CatTensors, - CenterCrop, - ClipTransform, - Compose, - ConditionalSkip, - Crop, - DataLoadingPrimer, - DeviceCastTransform, - DiscreteActionProjection, - DoubleToFloat, - DTypeCastTransform, - EndOfLifeTransform, - ExcludeTransform, - FiniteTensorDictCheck, - FlattenObservation, - FrameSkipTransform, - GrayScale, - gSDENoise, - Hash, - InitTracker, - KLRewardTransform, - LineariseRewards, - MultiAction, - MultiStepTransform, - NoopResetEnv, - ObservationNorm, - ObservationTransform, - PermuteTransform, - PinMemoryTransform, - R3MTransform, - RandomCropTensorDict, - RemoveEmptySpecs, - RenameTransform, - Resize, - Reward2GoTransform, - RewardClipping, - RewardScaling, - RewardSum, - SelectTransform, - SignTransform, - SqueezeTransform, - Stack, - StepCounter, - TargetReturn, - TensorDictPrimer, - TimeMaxPool, - Timer, - Tokenizer, - ToTensorImage, - TrajCounter, - Transform, - TransformedEnv, - UnaryTransform, - UnsqueezeTransform, - VC1Transform, - VecGymEnvTransform, - VecNorm, - VIPRewardTransform, - VIPTransform, -) -from .utils import ( - check_env_specs, - check_marl_grouping, - exploration_type, - ExplorationType, - make_composite_from_td, - MarlGroupMapType, - set_exploration_type, - step_mdp, -) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 0969c694058..5c47d3f46d6 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -19,21 +19,29 @@ from warnings import warn import torch -from tensordict import (LazyStackedTensorDict, TensorDict, TensorDictBase, is_tensor_collection, unravel_key) +from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDict, + TensorDictBase, + unravel_key, +) from tensordict.base import _is_leaf_nontensor from tensordict.utils import _zip_strict from torch import multiprocessing as mp + from torchrl._utils import ( - VERBOSE, - _ProcessNoWarn, _check_for_faulty_process, _make_ordinal_device, + _ProcessNoWarn, logger as torchrl_logger, + VERBOSE, ) from torchrl.data.tensor_specs import Composite, NonTensor -from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING, contains_lazy_spec -from torchrl.envs.common import EnvBase, EnvMetaData, _EnvPostInit, _do_nothing +from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING +from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData from torchrl.envs.env_creator 
import get_env_metadata + # legacy from torchrl.envs.libs.envpool import ( # noqa: F401 MultiThreadedEnv, @@ -864,7 +872,6 @@ def _shutdown_workers(self) -> None: def _set_seed(self, seed: int | None): """This method is not used in batched envs.""" - pass @lazy def start(self) -> None: diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index bbd780aadd7..7bec24cb17b 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -2,8 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .chess import ChessEnv -from .llm import LLMEnv, LLMHashingEnv -from .pendulum import PendulumEnv -from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index d4bceffe014..b720f9ff1bb 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -439,4 +439,3 @@ def _set_seed(self, *args): .. note:: This environment has no randomness, so this method does nothing. """ - pass diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index cba2b610648..b74030998ae 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -13,9 +13,10 @@ import numpy as np import torch from tensordict import NonTensorData, TensorDict, TensorDictBase + from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded -from torchrl.envs.common import EnvBase, _EnvWrapper, _maybe_unlock +from torchrl.envs.common import _EnvWrapper, _maybe_unlock, EnvBase class BaseInfoDictReader(metaclass=abc.ABCMeta): diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index 1cff97c1d49..7bec24cb17b 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -2,27 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
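Note: while the package `__init__.py` files are being emptied, `batched_envs.py` above deliberately keeps its envpool re-import and marks it `# noqa: F401`: the names stay importable from this module for backward compatibility, so the unused-import lint (and, presumably, the autoflake hook) must be told to leave them alone. The pattern, reduced to one line:

    # Keep a legacy name importable from this module; silence the
    # unused-import warning instead of deleting the import.
    from torchrl.envs.libs.envpool import MultiThreadedEnv  # noqa: F401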
- -from .brax import BraxEnv, BraxWrapper -from .dm_control import DMControlEnv, DMControlWrapper -from .envpool import MultiThreadedEnv, MultiThreadedEnvWrapper -from .gym import ( - gym_backend, - GymEnv, - GymWrapper, - MOGymEnv, - MOGymWrapper, - register_gym_spec_conversion, - set_gym_backend, -) -from .habitat import HabitatEnv -from .isaacgym import IsaacGymEnv, IsaacGymWrapper -from .jumanji import JumanjiEnv, JumanjiWrapper -from .meltingpot import MeltingpotEnv, MeltingpotWrapper -from .openml import OpenMLEnv -from .openspiel import OpenSpielEnv, OpenSpielWrapper -from .pettingzoo import PettingZooEnv, PettingZooWrapper -from .robohive import RoboHiveEnv -from .smacv2 import SMACv2Env, SMACv2Wrapper -from .unity_mlagents import UnityMLAgentsEnv, UnityMLAgentsWrapper -from .vmas import VmasEnv, VmasWrapper diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index a8f80aa6baf..19647f55841 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -11,7 +11,8 @@ import numpy as np import torch -from torchrl._utils import VERBOSE, logger as torchrl_logger + +from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.data.tensor_specs import ( Bounded, Categorical, diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 47a718caf15..f5c8c160ff9 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -10,7 +10,7 @@ import warnings from copy import copy from types import ModuleType -from typing import Dict, List +from typing import Dict from warnings import warn import numpy as np @@ -18,8 +18,10 @@ from packaging import version from tensordict import TensorDict, TensorDictBase from torch.utils._pytree import tree_map + from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( + _minmax_dtype, Binary, Bounded, Categorical, @@ -30,12 +32,11 @@ OneHot, TensorSpec, Unbounded, - _minmax_dtype, ) from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict from torchrl.envs.batched_envs import CloudpickleWrapper from torchrl.envs.common import _EnvPostInit -from torchrl.envs.gym_like import GymLikeEnv, default_info_dict_reader +from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv from torchrl.envs.utils import _classproperty try: diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index c8ff477fa06..6ef2dc8f130 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -5,14 +5,15 @@ from __future__ import annotations import importlib -from typing import Dict, List, Mapping, Sequence +from typing import Mapping, Sequence import torch from tensordict import TensorDict, TensorDictBase + from torchrl.data import Categorical, Composite, TensorSpec from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform -from torchrl.envs.utils import MarlGroupMapType, _classproperty, check_marl_grouping +from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType _has_meltingpot = importlib.util.find_spec("meltingpot") is not None diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py index 5c248e476a9..7dbfeac07a8 100644 --- a/torchrl/envs/libs/openspiel.py +++ b/torchrl/envs/libs/openspiel.py @@ -9,6 +9,7 @@ import torch from tensordict import TensorDict, TensorDictBase + from torchrl.data.tensor_specs import ( Categorical, Composite, @@ -17,7 +18,7 @@ Unbounded, ) from 
torchrl.envs.common import _EnvWrapper -from torchrl.envs.utils import MarlGroupMapType, _classproperty, check_marl_grouping +from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType _has_pyspiel = importlib.util.find_spec("pyspiel") is not None diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 3b35b1e1f32..31fed35baa7 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -13,10 +13,11 @@ import packaging import torch from tensordict import TensorDictBase + from torchrl.data.tensor_specs import Categorical, Composite, OneHot, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend -from torchrl.envs.utils import MarlGroupMapType, _classproperty, check_marl_grouping +from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType _has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None @@ -206,8 +207,8 @@ def available_envs(cls): def __init__( self, env: ( - pettingzoo.utils.env.ParallelEnv - | pettingzoo.utils.env.AECEnv # noqa: F821 # noqa: F821 + pettingzoo.utils.env.ParallelEnv # noqa: F821 + | pettingzoo.utils.env.AECEnv # noqa: F821 ) = None, return_state: bool = False, group_map: MarlGroupMapType | dict[str, list[str]] | None = None, @@ -271,8 +272,8 @@ def lib(self): def _build_env( self, env: ( - pettingzoo.utils.env.ParallelEnv - | pettingzoo.utils.env.AECEnv # noqa: F821 # noqa: F821 + pettingzoo.utils.env.ParallelEnv # noqa: F821 + | pettingzoo.utils.env.AECEnv # noqa: F821 ), ): import pettingzoo @@ -298,8 +299,8 @@ def _build_env( def _make_specs( self, env: ( - pettingzoo.utils.env.ParallelEnv - | pettingzoo.utils.env.AECEnv # noqa: F821 # noqa: F821 + pettingzoo.utils.env.ParallelEnv # noqa: F821 + | pettingzoo.utils.env.AECEnv # noqa: F821 ), ) -> None: # Set default for done on any or all @@ -998,8 +999,8 @@ def _build_env( parallel: bool, **kwargs, ) -> ( - pettingzoo.utils.env.ParallelEnv - | pettingzoo.utils.env.AECEnv # noqa: F821 # noqa: F821 + pettingzoo.utils.env.ParallelEnv # noqa: F821 + | pettingzoo.utils.env.AECEnv # noqa: F821 ): self.task_name = task diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index f471dee382c..c84450c7e94 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -11,9 +11,15 @@ import numpy as np import torch from tensordict import TensorDict + from torchrl.data.tensor_specs import Unbounded from torchrl.envs.common import _maybe_unlock -from torchrl.envs.libs.gym import (GymEnv, _GymAsyncMeta, _gym_to_torchrl_spec_transform, gym_backend) +from torchrl.envs.libs.gym import ( + _gym_to_torchrl_spec_transform, + _GymAsyncMeta, + gym_backend, + GymEnv, +) from torchrl.envs.utils import _classproperty, make_composite_from_td _has_gym = ( diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py index 5f5fc6c5654..149c20606ac 100644 --- a/torchrl/envs/libs/unity_mlagents.py +++ b/torchrl/envs/libs/unity_mlagents.py @@ -9,6 +9,7 @@ import torch from tensordict import TensorDict, TensorDictBase + from torchrl.data.tensor_specs import ( BoundedContinuous, Categorical, @@ -18,7 +19,7 @@ Unbounded, ) from torchrl.envs.common import _EnvWrapper -from torchrl.envs.utils import MarlGroupMapType, _classproperty, check_marl_grouping +from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType _has_unity_mlagents = 
importlib.util.find_spec("mlagents_envs") is not None diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 6e1bb1237e6..772e79b497e 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -6,10 +6,10 @@ import importlib.util import warnings -from typing import Dict, List import torch from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase + from torchrl.data.tensor_specs import ( Bounded, Categorical, @@ -23,9 +23,14 @@ Unbounded, ) from torchrl.data.utils import numpy_to_torch_dtype_dict -from torchrl.envs.common import EnvBase, _EnvWrapper +from torchrl.envs.common import _EnvWrapper, EnvBase from torchrl.envs.libs.gym import gym_backend, set_gym_backend -from torchrl.envs.utils import (MarlGroupMapType, _classproperty, _selective_unsqueeze, check_marl_grouping) +from torchrl.envs.utils import ( + _classproperty, + _selective_unsqueeze, + check_marl_grouping, + MarlGroupMapType, +) _has_vmas = importlib.util.find_spec("vmas") is not None diff --git a/torchrl/envs/model_based/__init__.py b/torchrl/envs/model_based/__init__.py index 437146a4909..7bec24cb17b 100644 --- a/torchrl/envs/model_based/__init__.py +++ b/torchrl/envs/model_based/__init__.py @@ -2,6 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .common import ModelBasedEnvBase -from .dreamer import DreamerDecoder, DreamerEnv diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 736bb7a2c9a..7bec24cb17b 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -2,71 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
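Note: the pettingzoo hunks above fix suppression placement rather than behavior. `# noqa: F821` only silences the line it sits on, so the old form, with both comments duplicated at the end of the second branch, left the first line of the parenthesized annotation flagged for the forward-referenced `pettingzoo` type. The corrected fragment, as it now reads in the wrapper's signature:

    env: (
        pettingzoo.utils.env.ParallelEnv  # noqa: F821
        | pettingzoo.utils.env.AECEnv  # noqa: F821
    ) = None,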
- -from .gym_transforms import EndOfLifeTransform -from .r3m import R3MTransform -from .rb_transforms import MultiStepTransform -from .rlhf import DataLoadingPrimer, KLRewardTransform -from .transforms import ( - ActionDiscretizer, - ActionMask, - AutoResetEnv, - AutoResetTransform, - BatchSizeTransform, - BinarizeReward, - BurnInTransform, - CatFrames, - CatTensors, - CenterCrop, - ClipTransform, - Compose, - ConditionalSkip, - Crop, - DeviceCastTransform, - DiscreteActionProjection, - DoubleToFloat, - DTypeCastTransform, - ExcludeTransform, - FiniteTensorDictCheck, - FlattenObservation, - FrameSkipTransform, - GrayScale, - gSDENoise, - Hash, - InitTracker, - LineariseRewards, - MultiAction, - NoopResetEnv, - ObservationNorm, - ObservationTransform, - PermuteTransform, - PinMemoryTransform, - RandomCropTensorDict, - RemoveEmptySpecs, - RenameTransform, - Resize, - Reward2GoTransform, - RewardClipping, - RewardScaling, - RewardSum, - SelectTransform, - SignTransform, - SqueezeTransform, - Stack, - StepCounter, - TargetReturn, - TensorDictPrimer, - TimeMaxPool, - Timer, - Tokenizer, - ToTensorImage, - TrajCounter, - Transform, - TransformedEnv, - UnaryTransform, - UnsqueezeTransform, - VecGymEnvTransform, - VecNorm, -) -from .vc1 import VC1Transform -from .vip import VIPRewardTransform, VIPTransform diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 8d2ac4ad858..37945c78d6b 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -6,7 +6,7 @@ from collections.abc import Mapping from copy import copy, deepcopy -from typing import Any, Callable, Iterable, List, Literal +from typing import Any, Callable, Iterable, Literal import torch from tensordict import ( @@ -19,6 +19,7 @@ from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams from tensordict.utils import _zip_strict, is_seq_of_nested_key from torch import nn + from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param @@ -339,8 +340,8 @@ def __init__( self, dataloader: Iterable[Any], primers: Composite | None = None, - data_keys: List[NestedKey] | None = None, - data_specs: List[TensorSpec] | None = None, + data_keys: list[NestedKey] | None = None, + data_specs: list[TensorSpec] | None = None, example_data: Any = None, stack_method: Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"] = None, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 2ee42d19667..4960e880f67 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -318,7 +318,6 @@ def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: def init(self, tensordict) -> None: """Runs init steps for the transform.""" - pass def _apply_transform(self, obs: torch.Tensor) -> None: """Applies the transform to a tensor or a leaf. 
@@ -1161,7 +1160,6 @@ def set_seed( def _set_seed(self, seed: int | None): """This method is not used in transformed envs.""" - pass def _reset(self, tensordict: TensorDictBase | None = None, **kwargs): if tensordict is not None: diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 0b9afecff02..c00be2403c4 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -6,8 +6,9 @@ from typing import List, Optional, Union import torch -from tensordict import TensorDict, TensorDictBase, set_lazy_legacy +from tensordict import set_lazy_legacy, TensorDict, TensorDictBase from torch.hub import load_state_dict_from_url + from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms.transforms import ( diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index edf90a4e85b..fafd97cb81b 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -3,93 +3,5 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from torchrl.modules.tensordict_module.common import DistributionalDQNnet -from .distributions import ( - Delta, - distributions_maps, - IndependentNormal, - MaskedCategorical, - MaskedOneHotCategorical, - NormalParamExtractor, - NormalParamWrapper, - OneHotCategorical, - OneHotOrdinal, - Ordinal, - ReparamGradientStrategy, - TanhDelta, - TanhNormal, - TruncatedNormal, -) -from .models import ( - BatchRenorm1d, - ConsistentDropout, - ConsistentDropoutModule, - Conv3dNet, - ConvNet, - DdpgCnnActor, - DdpgCnnQNet, - DdpgMlpActor, - DdpgMlpQNet, - DecisionTransformer, - DreamerActor, - DTActor, - DuelingCnnDQNet, - MLP, - MultiAgentConvNet, - MultiAgentMLP, - MultiAgentNetBase, - NoisyLazyLinear, - NoisyLinear, - ObsDecoder, - ObsEncoder, - OnlineDTActor, - QMixer, - reset_noise, - RSSMPosterior, - RSSMPrior, - RSSMRollout, - Squeeze2dLayer, - SqueezeLayer, - VDNMixer, -) -from .tensordict_module import ( - Actor, - ActorCriticOperator, - ActorCriticWrapper, - ActorValueOperator, - AdditiveGaussianModule, - AdditiveGaussianWrapper, - DecisionTransformerInferenceWrapper, - DistributionalQValueActor, - DistributionalQValueHook, - DistributionalQValueModule, - EGreedyModule, - EGreedyWrapper, - GRU, - GRUCell, - GRUModule, - LMHeadActorValueOperator, - LSTM, - LSTMCell, - LSTMModule, - MultiStepActorWrapper, - OrnsteinUhlenbeckProcessModule, - OrnsteinUhlenbeckProcessWrapper, - ProbabilisticActor, - QValueActor, - QValueHook, - QValueModule, - recurrent_mode, - SafeModule, - SafeProbabilisticModule, - SafeProbabilisticTensorDictSequential, - SafeSequential, - set_recurrent_mode, - TanhModule, - ValueOperator, - VmapModule, - WorldModelWrapper, -) -from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index 8f1b7da49a5..5ad3b972dec 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -3,24 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
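Note: the signature rewrites scattered through this series (`Dict[str, Any]` to `dict[str, Any]`, `Optional[List[...]]` to `list[...] | None`) are the PEP 585 and PEP 604 spellings, with the matching `typing` imports pruned. They remain valid on older interpreters because `from __future__ import annotations` (visible in several hunks, e.g. `rlhf/prompt.py`) keeps annotations unevaluated; this only covers annotations, not runtime type expressions. A hedged before/after sketch with illustrative function names:

    from __future__ import annotations

    from typing import Any

    # Before: def state_dict(self) -> Dict[str, Any]: ...
    def state_dict() -> dict[str, Any]:  # PEP 585 builtin generic
        return {}

    # Before: suffix: Optional[str] = None
    def configure(suffix: str | None = None) -> None:  # PEP 604 union
        ...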
-from tensordict.nn import NormalParamExtractor from torch import distributions as torch_dist -from .continuous import ( - Delta, - IndependentNormal, - NormalParamWrapper, - TanhDelta, - TanhNormal, - TruncatedNormal, -) +from .continuous import Delta, IndependentNormal, TanhDelta, TanhNormal, TruncatedNormal from .discrete import ( MaskedCategorical, MaskedOneHotCategorical, OneHotCategorical, OneHotOrdinal, Ordinal, - ReparamGradientStrategy, ) distributions_maps = { diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 460a712d3a2..48461e21ed6 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -14,12 +14,13 @@ from torch import distributions as D, nn from torch.distributions import constraints from torch.distributions.transforms import _InverseTransform + from torchrl.modules.distributions.truncated_normal import ( TruncatedNormal as _TruncatedNormal, ) from torchrl.modules.distributions.utils import ( - FasterTransformedDistribution, _cast_device, + FasterTransformedDistribution, safeatanh_noeps, safetanh_noeps, ) diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index f2ccf04c0fe..4c6e14cdff2 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -10,7 +10,7 @@ from numbers import Number import torch -from torch.distributions import Distribution, constraints +from torch.distributions import constraints, Distribution from torch.distributions.utils import broadcast_all CONST_SQRT_2 = math.sqrt(2) diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 35a060e8d69..7bec24cb17b 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -2,47 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- - -from torchrl.modules.tensordict_module.common import DistributionalDQNnet - -from .batchrenorm import BatchRenorm1d - -from .decision_transformer import DecisionTransformer -from .exploration import ( - ConsistentDropout, - ConsistentDropoutModule, - NoisyLazyLinear, - NoisyLinear, - reset_noise, -) -from .model_based import ( - DreamerActor, - ObsDecoder, - ObsEncoder, - RSSMPosterior, - RSSMPrior, - RSSMRollout, -) -from .models import ( - Conv2dNet, - Conv3dNet, - ConvNet, - DdpgCnnActor, - DdpgCnnQNet, - DdpgMlpActor, - DdpgMlpQNet, - DTActor, - DuelingCnnDQNet, - DuelingMlpDQNet, - MLP, - OnlineDTActor, -) -from .multiagent import ( - MultiAgentConvNet, - MultiAgentMLP, - MultiAgentNetBase, - QMixer, - VDNMixer, -) -from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index d5ac3157b8d..3458d494029 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -17,10 +17,11 @@ from torch.nn.modules.dropout import _DropoutNd from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import UninitializedBuffer, UninitializedParameter + from torchrl._utils import prod from torchrl.data.tensor_specs import Unbounded from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS -from torchrl.envs.utils import ExplorationType, exploration_type +from torchrl.envs.utils import exploration_type, ExplorationType from torchrl.modules.distributions.utils import _cast_transform_device from torchrl.modules.utils import inv_softplus diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 711e7cda87f..c1ce2b96f2b 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -7,20 +7,21 @@ import dataclasses from copy import deepcopy from numbers import Number -from typing import Callable, Dict, Sequence, Type +from typing import Callable, Sequence import torch from torch import nn + from torchrl._utils import prod from torchrl.data.utils import DEVICE_TYPING from torchrl.modules.models.decision_transformer import DecisionTransformer from torchrl.modules.models.utils import ( + _find_depth, + create_on_device, LazyMapping, SquashDims, Squeeze2dLayer, SqueezeLayer, - _find_depth, - create_on_device, ) from torchrl.modules.tensordict_module.common import DistributionalDQNnet # noqa diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 8dd7cbe8743..e4f923d34dd 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -7,7 +7,7 @@ import abc from copy import deepcopy from textwrap import indent -from typing import Sequence, Type +from typing import Sequence import numpy as np import torch diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index 510ee423a64..1ae6234a844 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -6,7 +6,7 @@ import inspect import warnings -from typing import Callable, Sequence, Type +from typing import Callable, Sequence import torch from torch import nn diff --git a/torchrl/modules/planners/__init__.py b/torchrl/modules/planners/__init__.py index 56c0e48bc65..7bec24cb17b 100644 --- a/torchrl/modules/planners/__init__.py +++ b/torchrl/modules/planners/__init__.py @@ -2,7 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -from .cem import CEMPlanner -from .common import MPCPlannerBase -from .mppi import MPPIPlanner diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 3fb1559833a..7bec24cb17b 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -2,47 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .actors import ( - Actor, - ActorCriticOperator, - ActorCriticWrapper, - ActorValueOperator, - DecisionTransformerInferenceWrapper, - DistributionalQValueActor, - DistributionalQValueHook, - DistributionalQValueModule, - LMHeadActorValueOperator, - MultiStepActorWrapper, - ProbabilisticActor, - QValueActor, - QValueHook, - QValueModule, - TanhModule, - ValueOperator, -) -from .common import SafeModule, VmapModule -from .exploration import ( - AdditiveGaussianModule, - AdditiveGaussianWrapper, - EGreedyModule, - EGreedyWrapper, - OrnsteinUhlenbeckProcessModule, - OrnsteinUhlenbeckProcessWrapper, -) -from .probabilistic import ( - SafeProbabilisticModule, - SafeProbabilisticTensorDictSequential, -) -from .rnn import ( - GRU, - GRUCell, - GRUModule, - LSTM, - LSTMCell, - LSTMModule, - recurrent_mode, - set_recurrent_mode, -) -from .sequence import SafeSequential -from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 4a33d0d7fbb..e4b91c1a543 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -4,21 +4,22 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Optional, Sequence +from typing import Sequence import torch from tensordict import TensorDictBase, unravel_key from tensordict.nn import ( CompositeDistribution, + dispatch, TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, TensorDictSequential, - dispatch, ) -from tensordict.utils import NestedKey, expand_as_right +from tensordict.utils import expand_as_right, NestedKey from torch import nn from torch.distributions import Categorical + from torchrl._utils import _replace_last from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _process_action_space_spec diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index f2b40e645e4..2bd09e81e81 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -13,10 +13,11 @@ import torch from tensordict import TensorDictBase, unravel_key_list -from tensordict.nn import TensorDictModule, TensorDictModuleBase, dispatch +from tensordict.nn import dispatch, TensorDictModule, TensorDictModuleBase from tensordict.utils import NestedKey from torch import nn from torch.nn import functional as F + from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import DEVICE_TYPING diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 5bbe40a725d..050b8e4f27e 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -14,10 +14,11 @@ TensorDictModuleBase, TensorDictModuleWrapper, ) -from tensordict.utils import NestedKey, expand_as_right, expand_right +from tensordict.utils import expand_as_right, expand_right, 
NestedKey from torch import nn + from torchrl.data.tensor_specs import Composite, TensorSpec -from torchrl.envs.utils import ExplorationType, exploration_type +from torchrl.envs.utils import exploration_type, ExplorationType from torchrl.modules.tensordict_module.common import _forward_hook_safe_action __all__ = [ diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index e03685d55d8..89e56672623 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -5,7 +5,6 @@ from __future__ import annotations import warnings -from typing import Dict, List import torch from tensordict import TensorDictBase, unravel_key_list diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index be23b985498..cf6b7ad5533 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -12,10 +12,11 @@ import torch.nn.functional as F from tensordict import TensorDictBase, unravel_key_list from tensordict.base import NO_DEFAULT -from tensordict.nn import TensorDictModuleBase as ModuleBase, dispatch +from tensordict.nn import dispatch, TensorDictModuleBase as ModuleBase from tensordict.utils import expand_as_right, prod, set_lazy_legacy -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase + from torchrl._utils import _ContextManager, _DecoratorContextManager from torchrl.data.tensor_specs import Unbounded from torchrl.objectives.value.functional import ( diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py index ae57de949bf..109744cb7ee 100644 --- a/torchrl/modules/utils/__init__.py +++ b/torchrl/modules/utils/__init__.py @@ -3,16 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import OrderedDict import torch from packaging import version if version.parse(torch.__version__) >= version.parse("1.12.0"): - from torch.nn.parameter import _disabled_torch_function_impl, _ParameterMeta + from torch.nn.parameter import _ParameterMeta else: - from torch.nn.parameter import _disabled_torch_function_impl + pass # Metaclass to combine _TensorMeta and the instance check override for Parameter. class _ParameterMeta(torch._C._TensorMeta): @@ -22,7 +21,3 @@ def __instancecheck__(self, instance): isinstance(instance, torch.Tensor) and getattr(instance, "_is_param", False) ) - - -from .mappings import biased_softplus, inv_softplus, mappings -from .utils import get_primers_from_module diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index f8f5636db95..7bec24cb17b 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -2,33 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -from .a2c import A2CLoss -from .common import LossModule -from .cql import CQLLoss, DiscreteCQLLoss -from .crossq import CrossQLoss -from .ddpg import DDPGLoss -from .decision_transformer import DTLoss, OnlineDTLoss -from .dqn import DistributionalDQNLoss, DQNLoss -from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss -from .gail import GAILLoss -from .iql import DiscreteIQLLoss, IQLLoss -from .multiagent import QMixerLoss -from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss -from .redq import REDQLoss -from .reinforce import ReinforceLoss -from .sac import DiscreteSACLoss, SACLoss -from .td3 import TD3Loss -from .td3_bc import TD3BCLoss -from .utils import ( - default_value_kwargs, - distance_loss, - group_optimizers, - HardUpdate, - hold_out_net, - hold_out_params, - next_state_value, - SoftUpdate, - TargetNetUpdater, - ValueEstimators, -) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 369f1d56ad0..2ebcd4120c7 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -9,28 +9,34 @@ from dataclasses import dataclass import torch -from tensordict import (TensorDict, TensorDictBase, TensorDictParams, is_tensor_collection) +from tensordict import ( + is_tensor_collection, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import ( - CompositeDistribution, - ProbabilisticTensorDictSequential, - TensorDictModule, composite_lp_aggregate, + CompositeDistribution, dispatch, + ProbabilisticTensorDictSequential, set_composite_lp_aggregate, + TensorDictModule, ) from tensordict.utils import NestedKey from torch import distributions as d + from torchrl.modules.distributions import HAS_ENTROPY from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, - _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, _clip_value_loss, + _GAMMA_LMBDA_DEPREC_ERROR, _get_default_device, _reduce, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import ( GAE, diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 33b00ff969b..f74bcc3b36f 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -13,11 +13,12 @@ from typing import Iterator import torch -from tensordict import TensorDict, TensorDictBase, is_tensor_collection +from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams from tensordict.utils import Buffer from torch import nn from torch.nn import Parameter + from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import set_recurrent_mode @@ -124,8 +125,6 @@ class _AcceptedKeys: default values. 
""" - pass - tensor_keys: _AcceptedKeys _vmap_randomness = None default_value_estimator: ValueEstimators = None diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index fd6c4231312..3ca291c09ae 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -8,15 +8,15 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Union import numpy as np import torch import torch.nn as nn from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey, unravel_key from torch import Tensor + from torchrl.data.tensor_specs import Composite from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -24,13 +24,13 @@ from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, - _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, _reduce, _vmap_func, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 4bf314bce4c..2f576f219b3 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -10,20 +10,21 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor + from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, _cache_values, _reduce, _vmap_func, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 6dd10de7e89..50973c7077f 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -10,17 +10,18 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey, unravel_key + from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, - _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, _reduce, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 03cdc42e667..1038a7151e8 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -9,9 +9,10 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey 
from torch import distributions as d + from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import _reduce, distance_loss diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index a9eeae6d5d4..7221f6b3be5 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -11,14 +11,20 @@ import numpy as np import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, composite_lp_aggregate, dispatch +from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor + from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp -from torchrl.objectives import ValueEstimators, default_value_kwargs, distance_loss +from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators from torchrl.objectives.common import LossModule -from torchrl.objectives.utils import (_GAMMA_LMBDA_DEPREC_ERROR, _cache_values, _reduce, _vmap_func) +from torchrl.objectives.utils import ( + _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, + _vmap_func, +) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index ad5b7711b5b..6a4373e1751 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -9,9 +9,10 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import nn + from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import _find_action_space from torchrl.envs.utils import step_mdp @@ -22,11 +23,11 @@ from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, _reduce, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TDLambdaEstimator from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 054b9de9d2c..0eb976da5d0 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -10,16 +10,17 @@ from tensordict import TensorDict from tensordict.nn import TensorDictModule from tensordict.utils import NestedKey + from torchrl._utils import timeit from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, hold_out_net, + ValueEstimators, ) # distance_loss, from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 087204d1ad6..58057636cf3 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -9,20 +9,21 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, 
TensorDictModule from tensordict.utils import NestedKey from torch import Tensor + from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import _find_action_space from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, _GAMMA_LMBDA_DEPREC_ERROR, _reduce, _vmap_func, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/multiagent/__init__.py b/torchrl/objectives/multiagent/__init__.py index 7340cffd841..7bec24cb17b 100644 --- a/torchrl/objectives/multiagent/__init__.py +++ b/torchrl/objectives/multiagent/__init__.py @@ -2,5 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .qmixer import QMixerLoss diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 66307403283..ce5752e70b0 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -11,9 +11,10 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import nn + from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import _find_action_space from torchrl.modules import SafeSequential @@ -21,11 +22,11 @@ from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, - _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TDLambdaEstimator from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 5cc5795584c..db887cf0fba 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -10,31 +10,37 @@ from dataclasses import dataclass import torch -from tensordict import (TensorDict, TensorDictBase, TensorDictParams, is_tensor_collection) +from tensordict import ( + is_tensor_collection, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import ( + composite_lp_aggregate, CompositeDistribution, + dispatch, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, - TensorDictModule, - composite_lp_aggregate, - dispatch, set_composite_lp_aggregate, + TensorDictModule, ) from tensordict.utils import NestedKey from torch import distributions as d + from torchrl._utils import _standardize from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, - _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, _clip_value_loss, + _GAMMA_LMBDA_DEPREC_ERROR, _maybe_add_or_extend_key, _maybe_get_or_select, _reduce, _sum_td_features, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import ( GAE, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index fca6375a634..6be5172f0fd 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -10,20 +10,21 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, 
composite_lp_aggregate, dispatch +from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor + from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, - _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, _reduce, _vmap_func, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 71e01a36883..d34313e5d8e 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -14,14 +14,15 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import ( - CompositeDistribution, - TensorDictModule, composite_lp_aggregate, + CompositeDistribution, dispatch, set_composite_lp_aggregate, + TensorDictModule, ) -from tensordict.utils import NestedKey, expand_right +from tensordict.utils import expand_right, NestedKey from torch import Tensor + from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -29,13 +30,13 @@ from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, - _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, _reduce, _vmap_func, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 707193ba653..40760ef95fb 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -8,19 +8,20 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey + from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, - _GAMMA_LMBDA_DEPREC_ERROR, _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, _reduce, _vmap_func, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index b64ad0758b9..b7292c4fdb2 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -8,18 +8,19 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import TensorDictModule, dispatch +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey + from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - ValueEstimators, _cache_values, _reduce, _vmap_func, default_value_kwargs, distance_loss, + ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator diff --git 
a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 5c8c760ca5a..4a7d8466ddb 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -13,7 +13,7 @@ import torch from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictModule -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn import functional as F from torch.nn.modules import dropout @@ -321,7 +321,7 @@ class SoftUpdate(TargetNetUpdater): def __init__( self, loss_module: ( - DQNLoss + DQNLoss # noqa: F821 | DDPGLoss # noqa: F821 | SACLoss # noqa: F821 | REDQLoss # noqa: F821 diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py index 51496986153..7bec24cb17b 100644 --- a/torchrl/objectives/value/__init__.py +++ b/torchrl/objectives/value/__init__.py @@ -2,15 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .advantages import ( - GAE, - TD0Estimate, - TD0Estimator, - TD1Estimate, - TD1Estimator, - TDLambdaEstimate, - TDLambdaEstimator, - ValueEstimatorBase, - VTrace, -) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 69bbcc91721..4f12809809e 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -10,25 +10,31 @@ from contextlib import nullcontext from dataclasses import asdict, dataclass from functools import wraps -from typing import Callable, Union +from typing import Callable import torch -from tensordict import TensorDictBase, is_tensor_collection +from tensordict import is_tensor_collection, TensorDictBase from tensordict.nn import ( - ProbabilisticTensorDictModule, - TensorDictModule, - TensorDictModuleBase, composite_lp_aggregate, dispatch, + ProbabilisticTensorDictModule, set_composite_lp_aggregate, set_skip_existing, + TensorDictModule, + TensorDictModuleBase, ) from tensordict.nn.probabilistic import interaction_type from tensordict.utils import NestedKey, unravel_key from torch import Tensor + from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp -from torchrl.objectives.utils import (RANDOM_MODULE_LIST, _maybe_get_or_select, _vmap_func, hold_out_net) +from torchrl.objectives.utils import ( + _maybe_get_or_select, + _vmap_func, + hold_out_net, + RANDOM_MODULE_LIST, +) from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -283,7 +289,6 @@ def in_keys(self): except AttributeError: # value network does not have an `in_keys` attribute in_keys = [] - pass return in_keys @property diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 87bdf1c630d..8484a025835 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -7,7 +7,6 @@ import math import warnings from functools import wraps -from typing import Union import torch diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py index f6c9bcdefbb..7bec24cb17b 100644 --- a/torchrl/record/__init__.py +++ b/torchrl/record/__init__.py @@ -2,6 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger -from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder diff --git a/torchrl/record/loggers/__init__.py b/torchrl/record/loggers/__init__.py index 92714675046..7bec24cb17b 100644 --- a/torchrl/record/loggers/__init__.py +++ b/torchrl/record/loggers/__init__.py @@ -2,12 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .common import Logger - -from .csv import CSVLogger -from .mlflow import MLFlowLogger -from .tensorboard import TensorboardLogger -from .utils import generate_exp_name, get_logger - -from .wandb import WandbLogger diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index b2a68964571..c2cf93dd119 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -7,7 +7,7 @@ import importlib.util import math from copy import copy -from typing import Callable, List, Sequence +from typing import Callable, Sequence import numpy as np import torch diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 9d593d64f17..7bec24cb17b 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -2,21 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .trainers import ( - BatchSubSampler, - ClearCudaCache, - CountFramesLog, - LogReward, - LogScalar, - LogValidationReward, - mask_batch, - OptimizerHook, - Recorder, - ReplayBufferTrainer, - RewardNormalizer, - SelectKeys, - Trainer, - TrainerHookBase, - UpdateWeights, -) diff --git a/torchrl/trainers/helpers/__init__.py b/torchrl/trainers/helpers/__init__.py index b09becdc15a..7bec24cb17b 100644 --- a/torchrl/trainers/helpers/__init__.py +++ b/torchrl/trainers/helpers/__init__.py @@ -2,21 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -from .collectors import ( - make_collector_offpolicy, - make_collector_onpolicy, - sync_async_collector, - sync_sync_collector, -) -from .envs import ( - correct_for_frame_skip, - get_stats_random_rollout, - parallel_env_constructor, - transformed_env_constructor, -) -from .logger import LoggerConfig -from .losses import make_dqn_loss, make_target_updater -from .models import make_dqn_actor, make_dreamer -from .replay_buffer import make_replay_buffer -from .trainers import make_trainer diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index cb76a1bcd26..d70a7358eb0 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -8,18 +8,24 @@ import abc import pathlib import warnings -from collections import OrderedDict, defaultdict +from collections import defaultdict, OrderedDict from copy import deepcopy from textwrap import indent from typing import Any, Callable, Sequence, Tuple import numpy as np import torch.nn -from tensordict import TensorDictBase, pad +from tensordict import pad, TensorDictBase from tensordict.nn import TensorDictModule from tensordict.utils import expand_right from torch import nn, optim -from torchrl._utils import (KeyDependentDefaultDict, VERBOSE, _CKPT_BACKEND, logger as torchrl_logger) + +from torchrl._utils import ( + _CKPT_BACKEND, + KeyDependentDefaultDict, + logger as torchrl_logger, + VERBOSE, +) from torchrl.collectors.collectors import DataCollectorBase from torchrl.collectors.utils import split_trajectories from torchrl.data.replay_buffers import ( diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 4c94177b880..06d89ad63ad 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -108,7 +108,6 @@ # If we can't set the method globally we can still run the parallel env with "fork" # This will fail on windows! 
Use "spawn" and put the script within `if __name__ == "__main__"` mp_context = "fork" - pass # sphinx_gallery_end_ignore import os From 0df5d7b65e0f0835a597272f7ea320b3766a2c1a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 28 Feb 2025 18:47:12 +0000 Subject: [PATCH 03/10] Update [ghstack-poisoned] --- torchrl/__init__.py | 14 +++ torchrl/collectors/__init__.py | 10 ++ torchrl/collectors/distributed/__init__.py | 6 + torchrl/data/map/__init__.py | 5 + torchrl/data/postprocs/__init__.py | 2 + torchrl/data/replay_buffers/__init__.py | 46 +++++++ torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/rlhf/__init__.py | 10 ++ torchrl/envs/__init__.py | 119 ++++++++++++++++++ torchrl/envs/custom/__init__.py | 5 + torchrl/envs/libs/__init__.py | 24 ++++ torchrl/envs/model_based/__init__.py | 3 + torchrl/envs/transforms/__init__.py | 68 ++++++++++ torchrl/modules/__init__.py | 88 +++++++++++++ torchrl/modules/distributions/__init__.py | 11 +- torchrl/modules/models/__init__.py | 44 +++++++ torchrl/modules/planners/__init__.py | 4 + torchrl/modules/tensordict_module/__init__.py | 44 +++++++ torchrl/modules/utils/__init__.py | 9 +- torchrl/objectives/__init__.py | 30 +++++ torchrl/objectives/multiagent/__init__.py | 2 + torchrl/objectives/value/__init__.py | 12 ++ torchrl/record/__init__.py | 3 + torchrl/record/loggers/__init__.py | 9 ++ torchrl/trainers/__init__.py | 18 +++ torchrl/trainers/helpers/__init__.py | 18 +++ 26 files changed, 602 insertions(+), 4 deletions(-) diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 81e97018596..42819ef1d8c 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -46,6 +46,20 @@ ) +import torchrl.collectors +import torchrl.data +import torchrl.envs +import torchrl.modules +import torchrl.objectives +import torchrl.trainers +from torchrl._utils import ( + auto_unwrap_transformed_env, + compile_with_warmup, + implement_for, + set_auto_unwrap_transformed_env, + timeit, +) + # Filter warnings in subprocesses: True by default given the multiple optional # deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`. filter_warnings_subprocess = True diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 7bec24cb17b..d69d8c9e50c 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -2,3 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from torchrl.envs.utils import RandomPolicy + +from .collectors import ( + aSyncDataCollector, + DataCollectorBase, + MultiaSyncDataCollector, + MultiSyncDataCollector, + SyncDataCollector, +) diff --git a/torchrl/collectors/distributed/__init__.py b/torchrl/collectors/distributed/__init__.py index 7bec24cb17b..97932619a60 100644 --- a/torchrl/collectors/distributed/__init__.py +++ b/torchrl/collectors/distributed/__init__.py @@ -2,3 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+ +from .generic import DEFAULT_SLURM_CONF, DistributedDataCollector +from .ray import RayCollector +from .rpc import RPCDataCollector +from .sync import DistributedSyncDataCollector +from .utils import submitit_delayed_launcher diff --git a/torchrl/data/map/__init__.py b/torchrl/data/map/__init__.py index 7bec24cb17b..c9bc25477c2 100644 --- a/torchrl/data/map/__init__.py +++ b/torchrl/data/map/__init__.py @@ -2,3 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .hash import BinaryToDecimal, RandomProjectionHash, SipHash +from .query import HashToInt, QueryModule +from .tdstorage import TensorDictMap, TensorMap +from .tree import MCTSForest, Tree diff --git a/torchrl/data/postprocs/__init__.py b/torchrl/data/postprocs/__init__.py index 7bec24cb17b..707740f6946 100644 --- a/torchrl/data/postprocs/__init__.py +++ b/torchrl/data/postprocs/__init__.py @@ -2,3 +2,5 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .postprocs import MultiStep diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 7bec24cb17b..4f230f30701 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -2,3 +2,49 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .checkpointers import ( + FlatStorageCheckpointer, + H5StorageCheckpointer, + ListStorageCheckpointer, + NestedStorageCheckpointer, + StorageCheckpointerBase, + StorageEnsembleCheckpointer, + TensorStorageCheckpointer, +) +from .replay_buffers import ( + PrioritizedReplayBuffer, + RemoteTensorDictReplayBuffer, + ReplayBuffer, + ReplayBufferEnsemble, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from .samplers import ( + PrioritizedSampler, + PrioritizedSliceSampler, + RandomSampler, + Sampler, + SamplerEnsemble, + SamplerWithoutReplacement, + SliceSampler, + SliceSamplerWithoutReplacement, +) +from .storages import ( + LazyMemmapStorage, + LazyStackStorage, + LazyTensorStorage, + ListStorage, + Storage, + StorageEnsemble, + TensorStorage, +) +from .utils import Flat2TED, H5Combine, H5Split, Nested2TED, TED2Flat, TED2Nested +from .writers import ( + ImmutableDatasetWriter, + RoundRobinWriter, + TensorDictMaxValueWriter, + TensorDictRoundRobinWriter, + Writer, + WriterEnsemble, +) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 3515b8eb8e4..7b54008a752 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -416,7 +416,7 @@ def __init__( max_size: int | None = None, *, compilable: bool = False, - stack_dim: int = -1, + stack_dim: int = 0, ): super().__init__(max_size=max_size, compilable=compilable) self.stack_dim = stack_dim diff --git a/torchrl/data/rlhf/__init__.py b/torchrl/data/rlhf/__init__.py index 7bec24cb17b..f0db092f2d1 100644 --- a/torchrl/data/rlhf/__init__.py +++ b/torchrl/data/rlhf/__init__.py @@ -2,3 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
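Besides the re-exports, the storages hunk above flips the default `stack_dim` of `LazyStackStorage` from `-1` to `0`; patch 04 below aligns the docstring. A minimal sketch of what the new default buys, with illustrative shapes:

    import torch
    from tensordict import TensorDict
    from torchrl.data.replay_buffers import LazyStackStorage, RandomSampler, ReplayBuffer

    rb = ReplayBuffer(storage=LazyStackStorage(max_size=100), sampler=RandomSampler())
    rb.add(TensorDict({"obs": torch.randn(3)}, batch_size=[]))
    rb.add(TensorDict({"obs": torch.randn(4)}, batch_size=[]))  # ragged entries are fine
    sample = rb.sample(2)
    print(sample.shape)  # torch.Size([2]); the lazy stack is built along dim 0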
+ +from .dataset import ( + create_infinite_iterator, + get_dataloader, + TensorDictTokenizer, + TokenizedDatasetLoader, +) +from .prompt import PromptData, PromptTensorDictTokenizer +from .reward import PairwiseDataset, RewardData +from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 7bec24cb17b..84f9fa8b0a6 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -2,3 +2,122 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .batched_envs import ParallelEnv, SerialEnv +from .common import EnvBase, EnvMetaData, make_tensordict +from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv +from .env_creator import env_creator, EnvCreator, get_env_metadata +from .gym_like import default_info_dict_reader, GymLikeEnv +from .libs import ( + BraxEnv, + BraxWrapper, + DMControlEnv, + DMControlWrapper, + gym_backend, + GymEnv, + GymWrapper, + HabitatEnv, + IsaacGymEnv, + IsaacGymWrapper, + JumanjiEnv, + JumanjiWrapper, + MeltingpotEnv, + MeltingpotWrapper, + MOGymEnv, + MOGymWrapper, + MultiThreadedEnv, + MultiThreadedEnvWrapper, + OpenMLEnv, + OpenSpielEnv, + OpenSpielWrapper, + PettingZooEnv, + PettingZooWrapper, + register_gym_spec_conversion, + RoboHiveEnv, + set_gym_backend, + SMACv2Env, + SMACv2Wrapper, + UnityMLAgentsEnv, + UnityMLAgentsWrapper, + VmasEnv, + VmasWrapper, +) +from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase +from .transforms import ( + ActionDiscretizer, + ActionMask, + AutoResetEnv, + AutoResetTransform, + BatchSizeTransform, + BinarizeReward, + BurnInTransform, + CatFrames, + CatTensors, + CenterCrop, + ClipTransform, + Compose, + ConditionalSkip, + Crop, + DeviceCastTransform, + DiscreteActionProjection, + DoubleToFloat, + DTypeCastTransform, + EndOfLifeTransform, + ExcludeTransform, + FiniteTensorDictCheck, + FlattenObservation, + FrameSkipTransform, + GrayScale, + gSDENoise, + Hash, + InitTracker, + KLRewardTransform, + LineariseRewards, + MultiAction, + MultiStepTransform, + NoopResetEnv, + ObservationNorm, + ObservationTransform, + PermuteTransform, + PinMemoryTransform, + R3MTransform, + RandomCropTensorDict, + RemoveEmptySpecs, + RenameTransform, + Resize, + Reward2GoTransform, + RewardClipping, + RewardScaling, + RewardSum, + SelectTransform, + SignTransform, + SqueezeTransform, + Stack, + StepCounter, + TargetReturn, + TensorDictPrimer, + TimeMaxPool, + Timer, + Tokenizer, + ToTensorImage, + TrajCounter, + Transform, + TransformedEnv, + UnaryTransform, + UnsqueezeTransform, + VC1Transform, + VecGymEnvTransform, + VecNorm, + VIPRewardTransform, + VIPTransform, +) +from .utils import ( + check_env_specs, + check_marl_grouping, + exploration_type, + ExplorationType, + make_composite_from_td, + MarlGroupMapType, + set_exploration_type, + step_mdp, +) diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index 7bec24cb17b..d2c85a7198f 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -2,3 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
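With `torchrl.envs` repopulated, environments, transforms and helpers all come from one flat import. A minimal sketch, again assuming `gymnasium` is installed:

    from torchrl.envs import check_env_specs, Compose, GymEnv, StepCounter, TransformedEnv

    env = TransformedEnv(GymEnv("CartPole-v1"), Compose(StepCounter(max_steps=100)))
    check_env_specs(env)  # raises if the declared specs do not match a real rollout
    rollout = env.rollout(3)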
+ +from .chess import ChessEnv +from .llm import LLMHashingEnv +from .pendulum import PendulumEnv +from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index 7bec24cb17b..1cff97c1d49 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -2,3 +2,27 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .brax import BraxEnv, BraxWrapper +from .dm_control import DMControlEnv, DMControlWrapper +from .envpool import MultiThreadedEnv, MultiThreadedEnvWrapper +from .gym import ( + gym_backend, + GymEnv, + GymWrapper, + MOGymEnv, + MOGymWrapper, + register_gym_spec_conversion, + set_gym_backend, +) +from .habitat import HabitatEnv +from .isaacgym import IsaacGymEnv, IsaacGymWrapper +from .jumanji import JumanjiEnv, JumanjiWrapper +from .meltingpot import MeltingpotEnv, MeltingpotWrapper +from .openml import OpenMLEnv +from .openspiel import OpenSpielEnv, OpenSpielWrapper +from .pettingzoo import PettingZooEnv, PettingZooWrapper +from .robohive import RoboHiveEnv +from .smacv2 import SMACv2Env, SMACv2Wrapper +from .unity_mlagents import UnityMLAgentsEnv, UnityMLAgentsWrapper +from .vmas import VmasEnv, VmasWrapper diff --git a/torchrl/envs/model_based/__init__.py b/torchrl/envs/model_based/__init__.py index 7bec24cb17b..437146a4909 100644 --- a/torchrl/envs/model_based/__init__.py +++ b/torchrl/envs/model_based/__init__.py @@ -2,3 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .common import ModelBasedEnvBase +from .dreamer import DreamerDecoder, DreamerEnv diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 7bec24cb17b..de12f1a0302 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -2,3 +2,71 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
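The custom environments re-exported here have no external dependencies, which makes them a convenient smoke test for the restored namespace. A minimal sketch:

    from torchrl.envs import PendulumEnv

    env = PendulumEnv()
    rollout = env.rollout(10)  # ten random actions
    assert rollout.shape == (10,)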
+ +from .gym_transforms import EndOfLifeTransform +from .r3m import R3MTransform +from .rb_transforms import MultiStepTransform +from .rlhf import KLRewardTransform +from .transforms import ( + ActionDiscretizer, + ActionMask, + AutoResetEnv, + AutoResetTransform, + BatchSizeTransform, + BinarizeReward, + BurnInTransform, + CatFrames, + CatTensors, + CenterCrop, + ClipTransform, + Compose, + ConditionalSkip, + Crop, + DeviceCastTransform, + DiscreteActionProjection, + DoubleToFloat, + DTypeCastTransform, + ExcludeTransform, + FiniteTensorDictCheck, + FlattenObservation, + FrameSkipTransform, + GrayScale, + gSDENoise, + Hash, + InitTracker, + LineariseRewards, + MultiAction, + NoopResetEnv, + ObservationNorm, + ObservationTransform, + PermuteTransform, + PinMemoryTransform, + RandomCropTensorDict, + RemoveEmptySpecs, + RenameTransform, + Resize, + Reward2GoTransform, + RewardClipping, + RewardScaling, + RewardSum, + SelectTransform, + SignTransform, + SqueezeTransform, + Stack, + StepCounter, + TargetReturn, + TensorDictPrimer, + TimeMaxPool, + Timer, + Tokenizer, + ToTensorImage, + TrajCounter, + Transform, + TransformedEnv, + UnaryTransform, + UnsqueezeTransform, + VecGymEnvTransform, + VecNorm, +) +from .vc1 import VC1Transform +from .vip import VIPRewardTransform, VIPTransform diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index fafd97cb81b..edf90a4e85b 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -3,5 +3,93 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from torchrl.modules.tensordict_module.common import DistributionalDQNnet +from .distributions import ( + Delta, + distributions_maps, + IndependentNormal, + MaskedCategorical, + MaskedOneHotCategorical, + NormalParamExtractor, + NormalParamWrapper, + OneHotCategorical, + OneHotOrdinal, + Ordinal, + ReparamGradientStrategy, + TanhDelta, + TanhNormal, + TruncatedNormal, +) +from .models import ( + BatchRenorm1d, + ConsistentDropout, + ConsistentDropoutModule, + Conv3dNet, + ConvNet, + DdpgCnnActor, + DdpgCnnQNet, + DdpgMlpActor, + DdpgMlpQNet, + DecisionTransformer, + DreamerActor, + DTActor, + DuelingCnnDQNet, + MLP, + MultiAgentConvNet, + MultiAgentMLP, + MultiAgentNetBase, + NoisyLazyLinear, + NoisyLinear, + ObsDecoder, + ObsEncoder, + OnlineDTActor, + QMixer, + reset_noise, + RSSMPosterior, + RSSMPrior, + RSSMRollout, + Squeeze2dLayer, + SqueezeLayer, + VDNMixer, +) +from .tensordict_module import ( + Actor, + ActorCriticOperator, + ActorCriticWrapper, + ActorValueOperator, + AdditiveGaussianModule, + AdditiveGaussianWrapper, + DecisionTransformerInferenceWrapper, + DistributionalQValueActor, + DistributionalQValueHook, + DistributionalQValueModule, + EGreedyModule, + EGreedyWrapper, + GRU, + GRUCell, + GRUModule, + LMHeadActorValueOperator, + LSTM, + LSTMCell, + LSTMModule, + MultiStepActorWrapper, + OrnsteinUhlenbeckProcessModule, + OrnsteinUhlenbeckProcessWrapper, + ProbabilisticActor, + QValueActor, + QValueHook, + QValueModule, + recurrent_mode, + SafeModule, + SafeProbabilisticModule, + SafeProbabilisticTensorDictSequential, + SafeSequential, + set_recurrent_mode, + TanhModule, + ValueOperator, + VmapModule, + WorldModelWrapper, +) +from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index 5ad3b972dec..8f1b7da49a5 100644 --- 
a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -3,15 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from tensordict.nn import NormalParamExtractor from torch import distributions as torch_dist -from .continuous import Delta, IndependentNormal, TanhDelta, TanhNormal, TruncatedNormal +from .continuous import ( + Delta, + IndependentNormal, + NormalParamWrapper, + TanhDelta, + TanhNormal, + TruncatedNormal, +) from .discrete import ( MaskedCategorical, MaskedOneHotCategorical, OneHotCategorical, OneHotOrdinal, Ordinal, + ReparamGradientStrategy, ) distributions_maps = { diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 7bec24cb17b..35a060e8d69 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -2,3 +2,47 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + + +from torchrl.modules.tensordict_module.common import DistributionalDQNnet + +from .batchrenorm import BatchRenorm1d + +from .decision_transformer import DecisionTransformer +from .exploration import ( + ConsistentDropout, + ConsistentDropoutModule, + NoisyLazyLinear, + NoisyLinear, + reset_noise, +) +from .model_based import ( + DreamerActor, + ObsDecoder, + ObsEncoder, + RSSMPosterior, + RSSMPrior, + RSSMRollout, +) +from .models import ( + Conv2dNet, + Conv3dNet, + ConvNet, + DdpgCnnActor, + DdpgCnnQNet, + DdpgMlpActor, + DdpgMlpQNet, + DTActor, + DuelingCnnDQNet, + DuelingMlpDQNet, + MLP, + OnlineDTActor, +) +from .multiagent import ( + MultiAgentConvNet, + MultiAgentMLP, + MultiAgentNetBase, + QMixer, + VDNMixer, +) +from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/planners/__init__.py b/torchrl/modules/planners/__init__.py index 7bec24cb17b..56c0e48bc65 100644 --- a/torchrl/modules/planners/__init__.py +++ b/torchrl/modules/planners/__init__.py @@ -2,3 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .cem import CEMPlanner +from .common import MPCPlannerBase +from .mppi import MPPIPlanner diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 7bec24cb17b..3fb1559833a 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -2,3 +2,47 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+ +from .actors import ( + Actor, + ActorCriticOperator, + ActorCriticWrapper, + ActorValueOperator, + DecisionTransformerInferenceWrapper, + DistributionalQValueActor, + DistributionalQValueHook, + DistributionalQValueModule, + LMHeadActorValueOperator, + MultiStepActorWrapper, + ProbabilisticActor, + QValueActor, + QValueHook, + QValueModule, + TanhModule, + ValueOperator, +) +from .common import SafeModule, VmapModule +from .exploration import ( + AdditiveGaussianModule, + AdditiveGaussianWrapper, + EGreedyModule, + EGreedyWrapper, + OrnsteinUhlenbeckProcessModule, + OrnsteinUhlenbeckProcessWrapper, +) +from .probabilistic import ( + SafeProbabilisticModule, + SafeProbabilisticTensorDictSequential, +) +from .rnn import ( + GRU, + GRUCell, + GRUModule, + LSTM, + LSTMCell, + LSTMModule, + recurrent_mode, + set_recurrent_mode, +) +from .sequence import SafeSequential +from .world_models import WorldModelWrapper diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py index 109744cb7ee..ae57de949bf 100644 --- a/torchrl/modules/utils/__init__.py +++ b/torchrl/modules/utils/__init__.py @@ -3,15 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import OrderedDict import torch from packaging import version if version.parse(torch.__version__) >= version.parse("1.12.0"): - from torch.nn.parameter import _ParameterMeta + from torch.nn.parameter import _disabled_torch_function_impl, _ParameterMeta else: - pass + from torch.nn.parameter import _disabled_torch_function_impl # Metaclass to combine _TensorMeta and the instance check override for Parameter. class _ParameterMeta(torch._C._TensorMeta): @@ -21,3 +22,7 @@ def __instancecheck__(self, instance): isinstance(instance, torch.Tensor) and getattr(instance, "_is_param", False) ) + + +from .mappings import biased_softplus, inv_softplus, mappings +from .utils import get_primers_from_module diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 7bec24cb17b..f8f5636db95 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -2,3 +2,33 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+ +from .a2c import A2CLoss +from .common import LossModule +from .cql import CQLLoss, DiscreteCQLLoss +from .crossq import CrossQLoss +from .ddpg import DDPGLoss +from .decision_transformer import DTLoss, OnlineDTLoss +from .dqn import DistributionalDQNLoss, DQNLoss +from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss +from .gail import GAILLoss +from .iql import DiscreteIQLLoss, IQLLoss +from .multiagent import QMixerLoss +from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss +from .redq import REDQLoss +from .reinforce import ReinforceLoss +from .sac import DiscreteSACLoss, SACLoss +from .td3 import TD3Loss +from .td3_bc import TD3BCLoss +from .utils import ( + default_value_kwargs, + distance_loss, + group_optimizers, + HardUpdate, + hold_out_net, + hold_out_params, + next_state_value, + SoftUpdate, + TargetNetUpdater, + ValueEstimators, +) diff --git a/torchrl/objectives/multiagent/__init__.py b/torchrl/objectives/multiagent/__init__.py index 7bec24cb17b..7340cffd841 100644 --- a/torchrl/objectives/multiagent/__init__.py +++ b/torchrl/objectives/multiagent/__init__.py @@ -2,3 +2,5 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .qmixer import QMixerLoss diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py index 7bec24cb17b..51496986153 100644 --- a/torchrl/objectives/value/__init__.py +++ b/torchrl/objectives/value/__init__.py @@ -2,3 +2,15 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .advantages import ( + GAE, + TD0Estimate, + TD0Estimator, + TD1Estimate, + TD1Estimator, + TDLambdaEstimate, + TDLambdaEstimator, + ValueEstimatorBase, + VTrace, +) diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py index 7bec24cb17b..f6c9bcdefbb 100644 --- a/torchrl/record/__init__.py +++ b/torchrl/record/__init__.py @@ -2,3 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger +from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder diff --git a/torchrl/record/loggers/__init__.py b/torchrl/record/loggers/__init__.py index 7bec24cb17b..92714675046 100644 --- a/torchrl/record/loggers/__init__.py +++ b/torchrl/record/loggers/__init__.py @@ -2,3 +2,12 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .common import Logger + +from .csv import CSVLogger +from .mlflow import MLFlowLogger +from .tensorboard import TensorboardLogger +from .utils import generate_exp_name, get_logger + +from .wandb import WandbLogger diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 7bec24cb17b..9d593d64f17 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -2,3 +2,21 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
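Restoring the flat `torchrl.objectives` namespace makes loss construction a few lines again. A minimal sketch with illustrative sizes (4 observation features, 2 discrete actions); `OneHot` is imported from `tensor_specs` directly because `torchrl.data`'s own re-exports only land in patch 05 below:

    from torchrl.data.tensor_specs import OneHot
    from torchrl.modules import MLP, QValueActor
    from torchrl.objectives import DQNLoss, SoftUpdate

    value_net = MLP(in_features=4, out_features=2)
    actor = QValueActor(value_net, in_keys=["observation"], spec=OneHot(2))
    loss = DQNLoss(actor, action_space="one-hot")
    updater = SoftUpdate(loss, eps=0.995)  # call updater.step() after each optimizer step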
+ +from .trainers import ( + BatchSubSampler, + ClearCudaCache, + CountFramesLog, + LogReward, + LogScalar, + LogValidationReward, + mask_batch, + OptimizerHook, + Recorder, + ReplayBufferTrainer, + RewardNormalizer, + SelectKeys, + Trainer, + TrainerHookBase, + UpdateWeights, +) diff --git a/torchrl/trainers/helpers/__init__.py b/torchrl/trainers/helpers/__init__.py index 7bec24cb17b..b09becdc15a 100644 --- a/torchrl/trainers/helpers/__init__.py +++ b/torchrl/trainers/helpers/__init__.py @@ -2,3 +2,21 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from .collectors import ( + make_collector_offpolicy, + make_collector_onpolicy, + sync_async_collector, + sync_sync_collector, +) +from .envs import ( + correct_for_frame_skip, + get_stats_random_rollout, + parallel_env_constructor, + transformed_env_constructor, +) +from .logger import LoggerConfig +from .losses import make_dqn_loss, make_target_updater +from .models import make_dqn_actor, make_dreamer +from .replay_buffer import make_replay_buffer +from .trainers import make_trainer From 3bc83f99f2ac6297da709f5d79550f34b9c30cc6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 28 Feb 2025 18:47:45 +0000 Subject: [PATCH 04/10] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/storages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 7b54008a752..1d6a4ac69e4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -382,7 +382,7 @@ class LazyStackStorage(ListStorage): Keyword Args: compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at the cost of being executable in multiprocessed settings. - stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `-1`. + stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `0`. Examples: >>> import torch From d7a6812adadce73d61561a0d61873f4be4900b10 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 09:38:28 +0000 Subject: [PATCH 05/10] Update [ghstack-poisoned] --- torchrl/data/__init__.py | 194 ++++++++++++++++++++++++++++++++ torchrl/envs/custom/chess.py | 9 +- torchrl/envs/custom/llm.py | 4 +- torchrl/envs/libs/isaacgym.py | 2 +- torchrl/envs/libs/meltingpot.py | 2 +- torchrl/envs/libs/openml.py | 2 +- 6 files changed, 207 insertions(+), 6 deletions(-) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 7bec24cb17b..9fc7b77a6b8 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -2,3 +2,197 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
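+# Re-export the data API (maps, replay buffers, RLHF helpers, tensor
+# specs) and pin down the public surface explicitly via `__all__` below.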
+ +from .map import ( + BinaryToDecimal, + HashToInt, + MCTSForest, + QueryModule, + RandomProjectionHash, + SipHash, + TensorDictMap, + TensorMap, + Tree, +) +from .postprocs import MultiStep +from .replay_buffers import ( + Flat2TED, + FlatStorageCheckpointer, + H5Combine, + H5Split, + H5StorageCheckpointer, + ImmutableDatasetWriter, + LazyMemmapStorage, + LazyStackStorage, + LazyTensorStorage, + ListStorage, + ListStorageCheckpointer, + Nested2TED, + NestedStorageCheckpointer, + PrioritizedReplayBuffer, + PrioritizedSampler, + PrioritizedSliceSampler, + RandomSampler, + RemoteTensorDictReplayBuffer, + ReplayBuffer, + ReplayBufferEnsemble, + RoundRobinWriter, + SamplerEnsemble, + SamplerWithoutReplacement, + SliceSampler, + SliceSamplerWithoutReplacement, + Storage, + StorageCheckpointerBase, + StorageEnsemble, + StorageEnsembleCheckpointer, + TED2Flat, + TED2Nested, + TensorDictMaxValueWriter, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, + TensorDictRoundRobinWriter, + TensorStorage, + TensorStorageCheckpointer, + Writer, + WriterEnsemble, +) +from .rlhf import ( + AdaptiveKLController, + ConstantKLController, + create_infinite_iterator, + get_dataloader, + PairwiseDataset, + PromptData, + PromptTensorDictTokenizer, + RewardData, + RolloutFromModel, + TensorDictTokenizer, + TokenizedDatasetLoader, +) +from .tensor_specs import ( + Binary, + BinaryDiscreteTensorSpec, + Bounded, + BoundedContinuous, + BoundedTensorSpec, + Categorical, + Choice, + Composite, + CompositeSpec, + DEVICE_TYPING, + DiscreteTensorSpec, + LazyStackedCompositeSpec, + LazyStackedTensorSpec, + MultiCategorical, + MultiDiscreteTensorSpec, + MultiOneHot, + MultiOneHotDiscreteTensorSpec, + NonTensor, + NonTensorSpec, + OneHot, + OneHotDiscreteTensorSpec, + Stacked, + StackedComposite, + TensorSpec, + Unbounded, + UnboundedContinuous, + UnboundedContinuousTensorSpec, + UnboundedDiscrete, + UnboundedDiscreteTensorSpec, +) +from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec + +__all__ = [ + "BinaryToDecimal", + "HashToInt", + "MCTSForest", + "QueryModule", + "RandomProjectionHash", + "SipHash", + "TensorDictMap", + "TensorMap", + "Tree", + "MultiStep", + "Flat2TED", + "FlatStorageCheckpointer", + "H5Combine", + "H5Split", + "H5StorageCheckpointer", + "ImmutableDatasetWriter", + "LazyMemmapStorage", + "LazyStackStorage", + "LazyTensorStorage", + "ListStorage", + "ListStorageCheckpointer", + "Nested2TED", + "NestedStorageCheckpointer", + "PrioritizedReplayBuffer", + "PrioritizedSampler", + "PrioritizedSliceSampler", + "RandomSampler", + "RemoteTensorDictReplayBuffer", + "ReplayBuffer", + "ReplayBufferEnsemble", + "RoundRobinWriter", + "SamplerEnsemble", + "SamplerWithoutReplacement", + "SliceSampler", + "SliceSamplerWithoutReplacement", + "Storage", + "StorageCheckpointerBase", + "StorageEnsemble", + "StorageEnsembleCheckpointer", + "TED2Flat", + "TED2Nested", + "TensorDictMaxValueWriter", + "TensorDictPrioritizedReplayBuffer", + "TensorDictReplayBuffer", + "TensorDictRoundRobinWriter", + "TensorStorage", + "TensorStorageCheckpointer", + "Writer", + "WriterEnsemble", + "AdaptiveKLController", + "ConstantKLController", + "create_infinite_iterator", + "get_dataloader", + "PairwiseDataset", + "PromptData", + "PromptTensorDictTokenizer", + "RewardData", + "RolloutFromModel", + "TensorDictTokenizer", + "TokenizedDatasetLoader", + "Binary", + "BinaryDiscreteTensorSpec", + "Bounded", + "BoundedContinuous", + "BoundedTensorSpec", + "Categorical", + "Choice", + "Composite", + 
"CompositeSpec", + "DEVICE_TYPING", + "DiscreteTensorSpec", + "LazyStackedCompositeSpec", + "LazyStackedTensorSpec", + "MultiCategorical", + "MultiDiscreteTensorSpec", + "MultiOneHot", + "MultiOneHotDiscreteTensorSpec", + "NonTensor", + "NonTensorSpec", + "OneHot", + "OneHotDiscreteTensorSpec", + "Stacked", + "StackedComposite", + "TensorSpec", + "Unbounded", + "UnboundedContinuous", + "UnboundedContinuousTensorSpec", + "UnboundedDiscrete", + "UnboundedDiscreteTensorSpec", + "check_no_exclusive_keys", + "consolidate_spec", + "contains_lazy_spec", +] diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py index 945a377255e..6706242fd24 100644 --- a/torchrl/envs/custom/chess.py +++ b/torchrl/envs/custom/chess.py @@ -10,7 +10,14 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded +from torchrl.data.tensor_specs import ( + Binary, + Bounded, + Categorical, + Composite, + NonTensor, + Unbounded, +) from torchrl.envs import EnvBase from torchrl.envs.common import _EnvPostInit from torchrl.envs.utils import _classproperty diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index ceb0953a11e..dd70a8c2598 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -11,12 +11,12 @@ from tensordict.tensorclass import NonTensorData, NonTensorStack from tensordict.utils import _zip_strict from torch.utils.data import DataLoader -from torchrl.data import ( +from torchrl.data.map.hash import SipHash +from torchrl.data.tensor_specs import ( Bounded, Categorical as CategoricalSpec, Composite, NonTensor, - SipHash, TensorSpec, Unbounded, ) diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py index 4d1bbe60d43..e8ae1c7bf79 100644 --- a/torchrl/envs/libs/isaacgym.py +++ b/torchrl/envs/libs/isaacgym.py @@ -12,7 +12,7 @@ import numpy as np import torch from tensordict import TensorDictBase -from torchrl.data import Composite +from torchrl.data.tensor_specs import Composite from torchrl.envs.libs.gym import GymWrapper from torchrl.envs.utils import _classproperty, make_composite_from_td diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 6ef2dc8f130..ce8e8e193c8 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -10,7 +10,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data import Categorical, Composite, TensorSpec +from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 831635f08cd..317fbd8531a 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -6,7 +6,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.envs.common import EnvBase From a1377321ee1125c50e6113dbdc689a771ef5b3ec Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 10:04:14 +0000 Subject: [PATCH 06/10] Update [ghstack-poisoned] --- check_future_imports.py | 33 +++++++++++ test/conftest.py | 2 + 
test/test_exploration.py | 3 +- test/test_transforms.py | 2 +- torchrl/_extension.py | 1 + .../collectors/distributed/default_configs.py | 6 ++ torchrl/collectors/distributed/utils.py | 7 +++ torchrl/data/datasets/d4rl_infos.py | 1 + torchrl/data/datasets/utils.py | 2 + torchrl/envs/custom/pendulum.py | 2 + torchrl/envs/libs/brax.py | 21 +++---- torchrl/envs/libs/habitat.py | 2 + torchrl/envs/libs/jax_utils.py | 9 +-- torchrl/envs/libs/openml.py | 2 + torchrl/envs/libs/robohive.py | 6 +- torchrl/envs/libs/utils.py | 5 +- torchrl/envs/model_based/common.py | 10 ++-- torchrl/envs/model_based/dreamer.py | 8 +-- torchrl/envs/transforms/functional.py | 4 +- torchrl/envs/transforms/gym_transforms.py | 5 +- torchrl/envs/transforms/r3m.py | 14 ++--- torchrl/envs/transforms/utils.py | 2 +- torchrl/envs/transforms/vc1.py | 5 +- torchrl/envs/transforms/vip.py | 15 ++--- torchrl/envs/vec_envs.py | 2 + torchrl/modules/distributions/discrete.py | 38 +++++++------ .../modules/distributions/truncated_normal.py | 4 +- torchrl/modules/distributions/utils.py | 4 +- torchrl/modules/models/batchrenorm.py | 2 + torchrl/modules/models/model_based.py | 2 + torchrl/modules/models/recipes/impala.py | 1 + torchrl/modules/models/rlhf.py | 2 + torchrl/modules/planners/cem.py | 1 + torchrl/modules/planners/common.py | 5 +- torchrl/modules/planners/mppi.py | 1 + .../modules/tensordict_module/world_models.py | 2 +- torchrl/modules/utils/mappings.py | 1 + torchrl/modules/utils/utils.py | 1 + torchrl/objectives/value/pg.py | 6 -- torchrl/objectives/value/utils.py | 6 +- torchrl/record/loggers/common.py | 7 ++- torchrl/record/loggers/mlflow.py | 12 ++-- torchrl/record/loggers/tensorboard.py | 8 ++- torchrl/record/loggers/utils.py | 2 +- torchrl/record/loggers/wandb.py | 10 ++-- torchrl/trainers/helpers/collectors.py | 55 ++++++++++--------- torchrl/trainers/helpers/envs.py | 38 +++++++------ torchrl/trainers/helpers/logger.py | 1 + torchrl/trainers/helpers/losses.py | 9 +-- torchrl/trainers/helpers/models.py | 6 +- torchrl/trainers/helpers/replay_buffer.py | 7 ++- torchrl/trainers/helpers/trainers.py | 20 +++---- 52 files changed, 256 insertions(+), 164 deletions(-) create mode 100644 check_future_imports.py delete mode 100644 torchrl/objectives/value/pg.py diff --git a/check_future_imports.py b/check_future_imports.py new file mode 100644 index 00000000000..796002993eb --- /dev/null +++ b/check_future_imports.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os + + +def find_files_without_future_annotations(directory): + """Finds Python files that do not contain 'from __future__ import annotations'.""" + files_without_annotations = [] + + for root, _, files in os.walk(directory): + for file in files: + if file.endswith(".py"): + file_path = os.path.join(root, file) + with open(file_path, encoding="utf-8") as f: + content = f.read() + if "from __future__ import annotations" not in content: + files_without_annotations.append(file_path) + + return files_without_annotations + + +if __name__ == "__main__": + repo_directory = "." 
# Change this to your repository's root directory + files = find_files_without_future_annotations(repo_directory) + if files: + print("Files without 'from __future__ import annotations':") + for file in files: + print(file) + else: + print("All files contain 'from __future__ import annotations'.") diff --git a/test/conftest.py b/test/conftest.py index f2648a18041..ba49735b98e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import os import sys diff --git a/test/test_exploration.py b/test/test_exploration.py index 554ebdc4aca..847cde926d1 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -379,9 +379,8 @@ def test_nested( ) action_spec = env.action_spec - d_act = action_spec.shape[-1] + action_spec.shape[-1] - nn.LazyLinear(d_act).to(device) policy = TensorDictModule( CountingEnvCountModule(action_spec=action_spec), in_keys=[("data", "states") if nested_obs_action else "observation"], diff --git a/test/test_transforms.py b/test/test_transforms.py index ffa81fde061..d1c2947f7af 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4785,7 +4785,7 @@ def make_env(): ) return env - SerialEnv(2, make_env) + SerialEnv(2, make_env).check_env_specs() def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): diff --git a/torchrl/_extension.py b/torchrl/_extension.py index 61eedb46418..d84d73cca4a 100644 --- a/torchrl/_extension.py +++ b/torchrl/_extension.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import importlib.util import warnings diff --git a/torchrl/collectors/distributed/default_configs.py b/torchrl/collectors/distributed/default_configs.py index edcaf6d91e4..8da69010242 100644 --- a/torchrl/collectors/distributed/default_configs.py +++ b/torchrl/collectors/distributed/default_configs.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + import os TCP_PORT = os.environ.get("TCP_PORT", "10003") diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index 2dd6fcf6c93..bc72bda6a4a 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -1,3 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + import subprocess import time diff --git a/torchrl/data/datasets/d4rl_infos.py b/torchrl/data/datasets/d4rl_infos.py index e9790ea04f9..c3e0a743f35 100644 --- a/torchrl/data/datasets/d4rl_infos.py +++ b/torchrl/data/datasets/d4rl_infos.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
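+# PEP 563: store annotations as strings rather than evaluating them.
+# Added across the codebase (see check_future_imports.py above) so that
+# PEP 604 unions such as `int | None` remain importable on Python 3.8.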
+from __future__ import annotations D4RL_DATASETS = { "maze2d-open-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5", diff --git a/torchrl/data/datasets/utils.py b/torchrl/data/datasets/utils.py index b88e3aee14e..975d69746ca 100644 --- a/torchrl/data/datasets/utils.py +++ b/torchrl/data/datasets/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import os diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index b530a01418e..bf99a768a61 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import numpy as np import torch diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index d6e4db8b0e8..ea080e20dfc 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -2,10 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import importlib.util import warnings -from typing import Dict, Optional, Union import torch from packaging import version @@ -209,7 +210,7 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): f"Setting a device in Brax wrapped environments is strongly recommended." ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): brax = self.lib if version.parse(brax.__version__) < version.parse("0.10.4"): raise ImportError("Brax v0.10.4 or greater is required.") @@ -223,12 +224,12 @@ def _check_kwargs(self, kwargs: Dict): def _build_env( self, env, - _seed: Optional[int] = None, + _seed: int | None = None, from_pixels: bool = False, - render_kwargs: Optional[dict] = None, + render_kwargs: dict | None = None, pixels_only: bool = False, requires_grad: bool = False, - camera_id: Union[int, str] = 0, + camera_id: int | str = 0, **kwargs, ): self.from_pixels = from_pixels @@ -241,7 +242,7 @@ def _build_env( ) return env - def _make_state_spec(self, env: "brax.envs.env.Env"): # noqa: F821 + def _make_state_spec(self, env: brax.envs.env.Env): # noqa: F821 jax = self.jax key = jax.random.PRNGKey(0) @@ -250,7 +251,7 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): # noqa: F821 state_spec = _extract_spec(state_dict).expand(self.batch_size) return state_spec - def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 + def _make_specs(self, env: brax.envs.env.Env) -> None: # noqa: F821 self.action_spec = Bounded( low=-1, high=1, @@ -291,7 +292,7 @@ def _make_state_example(self): state = _tree_reshape(state, self.batch_size) return state - def _init_env(self) -> Optional[int]: + def _init_env(self) -> int | None: jax = self.jax self._key = None self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset)) @@ -551,7 +552,7 @@ def _build_env( self, env_name: str, **kwargs, - ) -> "brax.envs.env.Env": # noqa: F821 + ) -> brax.envs.env.Env: # noqa: F821 if not _has_brax: raise ImportError( f"brax not found, unable to create {env_name}. 
" @@ -576,7 +577,7 @@ def _build_env( def env_name(self): return self._constructor_kwargs["env_name"] - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env_name" not in kwargs: raise TypeError("Expected 'env_name' to be part of kwargs") diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 999277a2db8..380ffa3697e 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import importlib.util diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 086533cb487..6b3d7520ba7 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -2,9 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import dataclasses import importlib.util -from typing import Union # import jax import numpy as np @@ -41,7 +42,7 @@ def _tree_flatten(x, batch_size: torch.Size): def _ndarray_to_tensor( - value: Union["jnp.ndarray", np.ndarray] # noqa: F821 + value: jnp.ndarray | np.ndarray # noqa: F821 ) -> torch.Tensor: from jax import dlpack as jax_dlpack, numpy as jnp @@ -59,7 +60,7 @@ def _ndarray_to_tensor( return out.to(numpy_to_torch_dtype_dict[value.dtype]) -def _tensor_to_ndarray(value: torch.Tensor) -> "jnp.ndarray": # noqa: F821 +def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: # noqa: F821 from jax import dlpack as jax_dlpack return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous())) @@ -148,7 +149,7 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size return type(object_example)(**t) -def _extract_spec(data: Union[torch.Tensor, TensorDictBase], key=None) -> TensorSpec: +def _extract_spec(data: torch.Tensor | TensorDictBase, key=None) -> TensorSpec: if isinstance(data, torch.Tensor): shape = data.shape if key in ("reward", "done"): diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 317fbd8531a..e4a2ed2c828 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import importlib.util import torch diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index c84450c7e94..fd02664a2a0 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import importlib import os import warnings @@ -162,7 +164,7 @@ def _build_env( # noqa: F811 pixels_only: bool = False, from_depths: bool = False, **kwargs, - ) -> "gym.core.Env": # noqa: F821 + ) -> gym.core.Env: # noqa: F821 if from_pixels: if "cameras" not in kwargs: warnings.warn( @@ -219,7 +221,7 @@ def _build_env( # noqa: F811 self.set_info_dict_reader(self.read_info) return env - def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 + def _make_specs(self, env: gym.Env, batch_size=None) -> None: # noqa: F821 out = super()._make_specs(env=env, batch_size=batch_size) self.env.reset() *_, info = self.env.step(self.env.action_space.sample()) diff --git a/torchrl/envs/libs/utils.py b/torchrl/envs/libs/utils.py index d7d7dbf6bfe..a4124c8dbf6 100644 --- a/torchrl/envs/libs/utils.py +++ b/torchrl/envs/libs/utils.py @@ -2,10 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +# # Copied from gym > 0.19 release - +# # this file should only be accessed when gym is installed +from __future__ import annotations import collections import copy diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index b2b64837808..ae09391485f 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -2,10 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import abc import warnings -from typing import List, Optional import torch from tensordict import TensorDict @@ -112,10 +112,10 @@ class ModelBasedEnvBase(EnvBase): def __init__( self, world_model: TensorDictModule, - params: Optional[List[torch.Tensor]] = None, - buffers: Optional[List[torch.Tensor]] = None, + params: list[torch.Tensor] | None = None, + buffers: list[torch.Tensor] | None = None, device: DEVICE_TYPING = "cpu", - batch_size: Optional[torch.Size] = None, + batch_size: torch.Size | None = None, run_type_checks: bool = False, ): super().__init__( @@ -173,6 +173,6 @@ def _step( def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: raise NotImplementedError - def _set_seed(self, seed: Optional[int]) -> int: + def _set_seed(self, seed: int | None) -> int: warnings.warn("Set seed isn't needed for model based environments") return seed diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 67e863c4799..8c5cfe5afcc 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -2,8 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
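+# Allows builtin-generic hints such as `tuple[int, ...]` below without
+# importing from `typing`, even on Python 3.8.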
+from __future__ import annotations -from typing import Optional, Tuple import torch from tensordict import TensorDict @@ -21,11 +21,11 @@ class DreamerEnv(ModelBasedEnvBase): def __init__( self, world_model: TensorDictModule, - prior_shape: Tuple[int, ...], - belief_shape: Tuple[int, ...], + prior_shape: tuple[int, ...], + belief_shape: tuple[int, ...], obs_decoder: TensorDictModule = None, device: DEVICE_TYPING = "cpu", - batch_size: Optional[torch.Size] = None, + batch_size: torch.Size | None = None, ): super().__init__(world_model, device=device, batch_size=batch_size) self.obs_decoder = obs_decoder diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py index 80bd1e089a0..ec6fbafd154 100644 --- a/torchrl/envs/transforms/functional.py +++ b/torchrl/envs/transforms/functional.py @@ -2,8 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations -from typing import List from torch import Tensor @@ -18,7 +18,7 @@ def _get_image_num_channels(img: Tensor) -> int: raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}") -def _assert_channels(img: Tensor, permitted: List[int]) -> None: +def _assert_channels(img: Tensor, permitted: list[int]) -> None: c = _get_image_num_channels(img) if c not in permitted: raise TypeError( diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py index dea05ad175f..c908207139c 100644 --- a/torchrl/envs/transforms/gym_transforms.py +++ b/torchrl/envs/transforms/gym_transforms.py @@ -4,6 +4,9 @@ # LICENSE file in the root directory of this source tree. """Gym-specific transforms.""" + +from __future__ import annotations + import warnings import torch @@ -186,7 +189,7 @@ def transform_observation_spec(self, observation_spec): ) return observation_spec - def register_keys(self, loss_or_advantage: "torchrl.objectives.common.LossModule"): + def register_keys(self, loss_or_advantage: torchrl.objectives.common.LossModule): """Registers the end-of-life key at appropriate places within the loss. Args: diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index dd150e8d94e..63b5ed44552 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -2,9 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
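+# `WeightsEnum` (a torchvision class) appears in annotations only, hence
+# the `noqa: F821`; deferred evaluation keeps torchvision optional here.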
+from __future__ import annotations import importlib.util -from typing import List, Optional, Union import torch from tensordict import set_lazy_legacy, TensorDict, TensorDictBase @@ -232,13 +232,13 @@ def __new__(cls, *args, **kwargs): def __init__( self, model_name: str, - in_keys: List[str], - out_keys: List[str] = None, + in_keys: list[str], + out_keys: list[str] = None, size: int = 244, stack_images: bool = True, - download: Union[bool, "WeightsEnum", str] = False, # noqa: F821 - download_path: Optional[str] = None, - tensor_pixels_keys: List[str] = None, + download: bool | WeightsEnum | str = False, # noqa: F821 + download_path: str | None = None, + tensor_pixels_keys: list[str] = None, ): super().__init__() self.in_keys = in_keys if in_keys is not None else ["pixels"] @@ -356,7 +356,7 @@ def _init(self): if self._dtype is not None: self.to(self._dtype) - def to(self, dest: Union[DEVICE_TYPING, torch.dtype]): + def to(self, dest: DEVICE_TYPING | torch.dtype): if isinstance(dest, torch.dtype): self._dtype = dest else: diff --git a/torchrl/envs/transforms/utils.py b/torchrl/envs/transforms/utils.py index a1b30cb1aca..8ef96c04ce0 100644 --- a/torchrl/envs/transforms/utils.py +++ b/torchrl/envs/transforms/utils.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +from __future__ import annotations import torch from torch import nn diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py index 76335bd8917..592237cbae3 100644 --- a/torchrl/envs/transforms/vc1.py +++ b/torchrl/envs/transforms/vc1.py @@ -3,11 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import importlib import os import subprocess from functools import partial -from typing import Union import torch from tensordict import TensorDictBase @@ -212,7 +213,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec return observation_spec - def to(self, dest: Union[DEVICE_TYPING, torch.dtype]): + def to(self, dest: DEVICE_TYPING | torch.dtype): if isinstance(dest, torch.dtype): self._dtype = dest else: diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index c00be2403c4..4bfcfc9b5ce 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -2,8 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import importlib.util -from typing import List, Optional, Union import torch from tensordict import set_lazy_legacy, TensorDict, TensorDictBase @@ -203,13 +204,13 @@ def __new__(cls, *args, **kwargs): def __init__( self, model_name: str, - in_keys: List[str] = None, - out_keys: List[str] = None, + in_keys: list[str] = None, + out_keys: list[str] = None, size: int = 244, stack_images: bool = True, - download: Union[bool, "WeightsEnum", str] = False, # noqa: F821 - download_path: Optional[str] = None, - tensor_pixels_keys: List[str] = None, + download: bool | WeightsEnum | str = False, # noqa: F821 + download_path: str | None = None, + tensor_pixels_keys: list[str] = None, ): super().__init__() self.in_keys = in_keys if in_keys is not None else ["pixels"] @@ -325,7 +326,7 @@ def _init(self): if self._dtype is not None: self.to(self._dtype) - def to(self, dest: Union[DEVICE_TYPING, torch.dtype]): + def to(self, dest: DEVICE_TYPING | torch.dtype): if isinstance(dest, torch.dtype): self._dtype = dest else: diff --git a/torchrl/envs/vec_envs.py b/torchrl/envs/vec_envs.py index 73dd159751c..e1956ccd9f7 100644 --- a/torchrl/envs/vec_envs.py +++ b/torchrl/envs/vec_envs.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import warnings warnings.warn("vec_env.py has moved to batch_envs.py.", category=DeprecationWarning) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index 9ba33806691..f3a5f081721 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -2,9 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from enum import Enum from functools import wraps -from typing import Any, Optional, Sequence, Union +from typing import Any, Sequence import torch import torch.distributions as D @@ -17,8 +19,8 @@ def _treat_categorical_params( - params: Optional[torch.Tensor] = None, -) -> Optional[torch.Tensor]: + params: torch.Tensor | None = None, +) -> torch.Tensor | None: if params is None: return None if params.shape[-1] == 1: @@ -94,8 +96,8 @@ def probs(self): def __init__( self, - logits: Optional[torch.Tensor] = None, - probs: Optional[torch.Tensor] = None, + logits: torch.Tensor | None = None, + probs: torch.Tensor | None = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs, ) -> None: @@ -127,11 +129,11 @@ def entropy(self): @_one_hot_wrapper(D.Categorical) def sample( - self, sample_shape: Optional[Union[torch.Size, Sequence]] = None + self, sample_shape: torch.Size | Sequence | None = None ) -> torch.Tensor: ... 
- def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Tensor: + def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor: if sample_shape is None: sample_shape = torch.Size([]) if hasattr(self, "logits") and self.logits is not None: @@ -217,13 +219,13 @@ def probs(self): def __init__( self, - logits: Optional[torch.Tensor] = None, - probs: Optional[torch.Tensor] = None, + logits: torch.Tensor | None = None, + probs: torch.Tensor | None = None, *, mask: torch.Tensor = None, indices: torch.Tensor = None, neg_inf: float = float("-inf"), - padding_value: Optional[int] = None, + padding_value: int | None = None, ) -> None: if not ((mask is None) ^ (indices is None)): raise ValueError( @@ -261,7 +263,7 @@ def __init__( self.num_samples = num_samples def sample( - self, sample_shape: Optional[Union[torch.Size, Sequence[int]]] = None + self, sample_shape: torch.Size | Sequence[int] | None = None ) -> torch.Tensor: if sample_shape is None: sample_shape = torch.Size() @@ -298,10 +300,10 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor: @staticmethod def _mask_logits( logits: torch.Tensor, - mask: Optional[torch.Tensor] = None, + mask: torch.Tensor | None = None, neg_inf: float = float("-inf"), sparse_mask: bool = False, - padding_value: Optional[int] = None, + padding_value: int | None = None, ) -> torch.Tensor: if mask is None: return logits @@ -401,12 +403,12 @@ def probs(self): def __init__( self, - logits: Optional[torch.Tensor] = None, - probs: Optional[torch.Tensor] = None, + logits: torch.Tensor | None = None, + probs: torch.Tensor | None = None, mask: torch.Tensor = None, indices: torch.Tensor = None, neg_inf: float = float("-inf"), - padding_value: Optional[int] = None, + padding_value: int | None = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, ) -> None: self.grad_method = grad_method @@ -421,7 +423,7 @@ def __init__( @_one_hot_wrapper(MaskedCategorical) def sample( - self, sample_shape: Optional[Union[torch.Size, Sequence[int]]] = None + self, sample_shape: torch.Size | Sequence[int] | None = None ) -> torch.Tensor: ... @@ -439,7 +441,7 @@ def mode(self) -> torch.Tensor: def log_prob(self, value: torch.Tensor) -> torch.Tensor: return super().log_prob(value.argmax(dim=-1)) - def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Tensor: + def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor: if sample_shape is None: sample_shape = torch.Size([]) if hasattr(self, "logits") and self.logits is not None: diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index 4c6e14cdff2..f8d481265cb 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -2,9 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - - +# # from https://github.com/toshas/torch_truncnorm +from __future__ import annotations import math from numbers import Number diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index fb64aa851d3..d807c511fd6 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -2,8 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations -from typing import Union import torch from torch import autograd, distributions as d @@ -15,7 +15,7 @@ from torch._dynamo import is_compiling as is_dynamo_compiling -def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]: +def _cast_device(elt: torch.Tensor | float, device) -> torch.Tensor | float: if isinstance(elt, torch.Tensor): return elt.to(device) return elt diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py index c5534568af7..0f266e423b2 100644 --- a/torchrl/modules/models/batchrenorm.py +++ b/torchrl/modules/models/batchrenorm.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import torch import torch.nn as nn diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 60d4dd020ef..976f57dd5b9 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import warnings import torch diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py index da92caa08dd..7173a030c5c 100644 --- a/torchrl/modules/models/recipes/impala.py +++ b/torchrl/modules/models/recipes/impala.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch import torch.nn as nn diff --git a/torchrl/modules/models/rlhf.py b/torchrl/modules/models/rlhf.py index 48953e43a4a..8b4f01a38c6 100644 --- a/torchrl/modules/models/rlhf.py +++ b/torchrl/modules/models/rlhf.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import importlib from pathlib import Path diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index abc0e3d3f95..9739ce5e592 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch from tensordict import TensorDict, TensorDictBase diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 3d6a4961f50..35703e6cad7 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -2,8 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import abc -from typing import Optional import torch from tensordict import TensorDictBase @@ -52,7 +53,7 @@ def planning(self, td: TensorDictBase) -> torch.Tensor: def forward( self, tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, + tensordict_out: TensorDictBase | None = None, **kwargs, ) -> TensorDictBase: if "params" in kwargs or "vmap" in kwargs: diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index 31c95650d25..e4b33ced697 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch from tensordict import TensorDict, TensorDictBase diff --git a/torchrl/modules/tensordict_module/world_models.py b/torchrl/modules/tensordict_module/world_models.py index 78384196926..ae3ea4d9a00 100644 --- a/torchrl/modules/tensordict_module/world_models.py +++ b/torchrl/modules/tensordict_module/world_models.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +from __future__ import annotations from tensordict.nn import TensorDictModule, TensorDictSequential diff --git a/torchrl/modules/utils/mappings.py b/torchrl/modules/utils/mappings.py index a9e3ab189d5..ebcf776f605 100644 --- a/torchrl/modules/utils/mappings.py +++ b/torchrl/modules/utils/mappings.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations from tensordict.nn.utils import biased_softplus, expln, inv_softplus, mappings diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py index 9a8914aab89..cb1e66a7f98 100644 --- a/torchrl/modules/utils/utils.py +++ b/torchrl/modules/utils/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import warnings diff --git a/torchrl/objectives/value/pg.py b/torchrl/objectives/value/pg.py deleted file mode 100644 index d62fe90a685..00000000000 --- a/torchrl/objectives/value/pg.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# implements a function that takes a sequence of returns and multiply its by the policy log_prob to get a differentiable objective diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index c28f19e3062..2abf245a8ff 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -2,8 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations -from typing import Union import torch @@ -212,7 +212,7 @@ def _get_num_per_traj(done): def _split_and_pad_sequence( - tensor: Union[torch.Tensor, TensorDictBase], + tensor: torch.Tensor | TensorDictBase, splits: torch.Tensor, return_mask=False, time_dim=-1, @@ -318,7 +318,7 @@ def _fill_tensor(tensor): def _inv_pad_sequence( - tensor: Union[torch.Tensor, TensorDictBase], + tensor: torch.Tensor | TensorDictBase, splits: torch.Tensor, mask: torch.Tensor = None, ): diff --git a/torchrl/record/loggers/common.py b/torchrl/record/loggers/common.py index b8325763166..e6db65eb816 100644 --- a/torchrl/record/loggers/common.py +++ b/torchrl/record/loggers/common.py @@ -2,9 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import abc -from typing import Dict, Sequence, Union +from typing import Sequence from torch import Tensor @@ -21,7 +22,7 @@ def __init__(self, exp_name: str, log_dir: str) -> None: self.experiment = self._create_experiment() @abc.abstractmethod - def _create_experiment(self) -> "Experiment": # noqa: F821 + def _create_experiment(self) -> Experiment: # noqa: F821 ... @abc.abstractmethod @@ -33,7 +34,7 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non ... @abc.abstractmethod - def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 ... @abc.abstractmethod diff --git a/torchrl/record/loggers/mlflow.py b/torchrl/record/loggers/mlflow.py index 548d8213279..e2df9f30f42 100644 --- a/torchrl/record/loggers/mlflow.py +++ b/torchrl/record/loggers/mlflow.py @@ -2,11 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import importlib.util import os from tempfile import TemporaryDirectory -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Sequence from torch import Tensor @@ -34,7 +36,7 @@ def __init__( self, exp_name: str, tracking_uri: str, - tags: Optional[Dict[str, Any]] = None, + tags: dict[str, Any] | None = None, *, video_fps: int = 30, **kwargs, @@ -51,7 +53,7 @@ def __init__( self.video_log_counter = 0 self.video_fps = video_fps - def _create_experiment(self) -> "mlflow.ActiveRun": # noqa + def _create_experiment(self) -> mlflow.ActiveRun: # noqa import mlflow """Creates an mlflow experiment. @@ -70,7 +72,7 @@ def _create_experiment(self) -> "mlflow.ActiveRun": # noqa self.id = experiment.experiment_id return mlflow.start_run(experiment_id=self.id) - def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None: + def log_scalar(self, name: str, value: float, step: int | None = None) -> None: """Logs a scalar value to mlflow. Args: @@ -118,7 +120,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: torchvision.io.write_video(filename=f.name, video_array=video, fps=fps) mlflow.log_artifact(f.name, "videos") - def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 """Logs the hyperparameters of the experiment. 
Args: diff --git a/torchrl/record/loggers/tensorboard.py b/torchrl/record/loggers/tensorboard.py index 5ecc9742614..39518807046 100644 --- a/torchrl/record/loggers/tensorboard.py +++ b/torchrl/record/loggers/tensorboard.py @@ -2,10 +2,12 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import importlib.util import os -from typing import Dict, Sequence, Union +from typing import Sequence from torch import Tensor @@ -31,7 +33,7 @@ def __init__(self, exp_name: str, log_dir: str = "tb_logs") -> None: self._has_imported_moviepy = False - def _create_experiment(self) -> "SummaryWriter": # noqa + def _create_experiment(self) -> SummaryWriter: # noqa """Creates a tensorboard experiment. Args: @@ -91,7 +93,7 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non **kwargs, ) - def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 """Logs the hyperparameters of the experiment. Args: diff --git a/torchrl/record/loggers/utils.py b/torchrl/record/loggers/utils.py index 226135f333f..5fe443db301 100644 --- a/torchrl/record/loggers/utils.py +++ b/torchrl/record/loggers/utils.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +from __future__ import annotations import os import pathlib diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py index c015c2b0214..3d23a485458 100644 --- a/torchrl/record/loggers/wandb.py +++ b/torchrl/record/loggers/wandb.py @@ -2,11 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import importlib.util import os import warnings -from typing import Dict, Optional, Sequence, Union +from typing import Sequence from torch import Tensor @@ -93,7 +95,7 @@ def __init__( self.video_log_counter = 0 - def _create_experiment(self) -> "WandbLogger": + def _create_experiment(self) -> WandbLogger: """Creates a wandb experiment. Args: @@ -111,7 +113,7 @@ def _create_experiment(self) -> "WandbLogger": return wandb.init(**self._wandb_kwargs) - def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None: + def log_scalar(self, name: str, value: float, step: int | None = None) -> None: """Logs a scalar value to wandb. Args: @@ -173,7 +175,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: **kwargs, ) - def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 """Logs the hyperparameters of the experiment. Args: diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index d60773a04c2..e278ca1d4ae 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -2,9 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
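+# `DictConfig` is referenced purely in annotations (note the
+# `noqa: F821`), so omegaconf does not need to be imported here.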
+from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable from tensordict import TensorDictBase @@ -22,10 +23,10 @@ def sync_async_collector( - env_fns: Union[Callable, List[Callable]], - env_kwargs: Optional[Union[dict, List[dict]]], - num_env_per_collector: Optional[int] = None, - num_collectors: Optional[int] = None, + env_fns: Callable | list[Callable], + env_kwargs: dict | list[dict] | None, + num_env_per_collector: int | None = None, + num_collectors: int | None = None, **kwargs, ) -> MultiaSyncDataCollector: """Runs asynchronous collectors, each running synchronous environments. @@ -82,12 +83,12 @@ def sync_async_collector( def sync_sync_collector( - env_fns: Union[Callable, List[Callable]], - env_kwargs: Optional[Union[dict, List[dict]]], - num_env_per_collector: Optional[int] = None, - num_collectors: Optional[int] = None, + env_fns: Callable | list[Callable], + env_kwargs: dict | list[dict] | None, + num_env_per_collector: int | None = None, + num_collectors: int | None = None, **kwargs, -) -> Union[SyncDataCollector, MultiSyncDataCollector]: +) -> SyncDataCollector | MultiSyncDataCollector: """Runs synchronous collectors, each running synchronous environments. E.g. @@ -164,16 +165,16 @@ def sync_sync_collector( def _make_collector( - collector_class: Type, - env_fns: Union[Callable, List[Callable]], - env_kwargs: Optional[Union[dict, List[dict]]], + collector_class: type, + env_fns: Callable | list[Callable], + env_kwargs: dict | list[dict] | None, policy: Callable[[TensorDictBase], TensorDictBase], max_frames_per_traj: int = -1, frames_per_batch: int = 200, - total_frames: Optional[int] = None, - postproc: Optional[Callable] = None, - num_env_per_collector: Optional[int] = None, - num_collectors: Optional[int] = None, + total_frames: int | None = None, + postproc: Callable | None = None, + num_env_per_collector: int | None = None, + num_collectors: int | None = None, **kwargs, ) -> DataCollectorBase: if env_kwargs is None: @@ -249,11 +250,11 @@ def _make_collector( def make_collector_offpolicy( make_env: Callable[[], EnvBase], - actor_model_explore: Union[ - TensorDictModuleWrapper, ProbabilisticTensorDictSequential - ], - cfg: "DictConfig", # noqa: F821 - make_env_kwargs: Optional[Dict] = None, + actor_model_explore: ( + TensorDictModuleWrapper | ProbabilisticTensorDictSequential + ), + cfg: DictConfig, # noqa: F821 + make_env_kwargs: dict | None = None, ) -> DataCollectorBase: """Returns a data collector for off-policy sota-implementations. @@ -313,11 +314,11 @@ def make_collector_offpolicy( def make_collector_onpolicy( make_env: Callable[[], EnvBase], - actor_model_explore: Union[ - TensorDictModuleWrapper, ProbabilisticTensorDictSequential - ], - cfg: "DictConfig", # noqa: F821 - make_env_kwargs: Optional[Dict] = None, + actor_model_explore: ( + TensorDictModuleWrapper | ProbabilisticTensorDictSequential + ), + cfg: DictConfig, # noqa: F821 + make_env_kwargs: dict | None = None, ) -> DataCollectorBase: """Makes a collector in on-policy settings. diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index e236b61c8e5..170a82f3414 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -2,9 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + from copy import copy from dataclasses import dataclass, field as dataclass_field -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Sequence import torch @@ -44,7 +46,7 @@ } -def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig": # noqa: F821 +def correct_for_frame_skip(cfg: DictConfig) -> DictConfig: # noqa: F821 """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip. This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targeting a total number of frames @@ -208,20 +210,20 @@ def get_norm_state_dict(env): def transformed_env_constructor( - cfg: "DictConfig", # noqa: F821 + cfg: DictConfig, # noqa: F821 video_tag: str = "", - logger: Optional[Logger] = None, - stats: Optional[dict] = None, + logger: Logger | None = None, + stats: dict | None = None, norm_obs_only: bool = False, use_env_creator: bool = False, - custom_env_maker: Optional[Callable] = None, - custom_env: Optional[EnvBase] = None, + custom_env_maker: Callable | None = None, + custom_env: EnvBase | None = None, return_transformed_envs: bool = True, - action_dim_gsde: Optional[int] = None, - state_dim_gsde: Optional[int] = None, - batch_dims: Optional[int] = 0, - obs_norm_state_dict: Optional[dict] = None, -) -> Union[Callable, EnvCreator]: + action_dim_gsde: int | None = None, + state_dim_gsde: int | None = None, + batch_dims: int | None = 0, + obs_norm_state_dict: dict | None = None, +) -> Callable | EnvCreator: """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. Args: @@ -326,8 +328,8 @@ def make_transformed_env(**kwargs) -> TransformedEnv: def parallel_env_constructor( - cfg: "DictConfig", **kwargs # noqa: F821 -) -> Union[ParallelEnv, EnvCreator]: + cfg: DictConfig, **kwargs # noqa: F821 +) -> ParallelEnv | EnvCreator: """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor. Args: @@ -370,9 +372,9 @@ def parallel_env_constructor( @torch.no_grad() def get_stats_random_rollout( - cfg: "DictConfig", # noqa: F821 + cfg: DictConfig, # noqa: F821 proof_environment: EnvBase = None, - key: Optional[str] = None, + key: str | None = None, ): """Gathers stas (loc and scale) from an environment using random rollouts. @@ -450,7 +452,7 @@ def get_stats_random_rollout( def initialize_observation_norm_transforms( proof_environment: EnvBase, num_iter: int = 1000, - key: Union[str, Tuple[str, ...]] = None, + key: str | tuple[str, ...] = None, ): """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`. @@ -530,7 +532,7 @@ class EnvConfig: # maximum steps per trajectory, frames per batch or any other factor in the algorithm, # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4 # the actual number of frames retrieved will be 200e6. Default=1. - reward_scaling: Optional[float] = None + reward_scaling: float | None = None # scale of the reward. reward_loc: float = 0.0 # location of the reward. diff --git a/torchrl/trainers/helpers/logger.py b/torchrl/trainers/helpers/logger.py index b0b37533519..b06c3593557 100644 --- a/torchrl/trainers/helpers/logger.py +++ b/torchrl/trainers/helpers/logger.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations from dataclasses import dataclass, field from typing import Any diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index 152d7e2891f..91c8f5f8675 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -2,9 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any from torchrl.objectives import DistributionalDQNLoss, DQNLoss, HardUpdate, SoftUpdate from torchrl.objectives.common import LossModule @@ -12,8 +13,8 @@ def make_target_updater( - cfg: "DictConfig", loss_module: LossModule # noqa: F821 -) -> Optional[TargetNetUpdater]: + cfg: DictConfig, loss_module: LossModule # noqa: F821 +) -> TargetNetUpdater | None: """Builds a target network weight update object.""" if cfg.loss == "double": if not cfg.hard_update: @@ -35,7 +36,7 @@ def make_target_updater( return target_net_updater -def make_dqn_loss(model, cfg) -> Tuple[DQNLoss, Optional[TargetNetUpdater]]: +def make_dqn_loss(model, cfg) -> tuple[DQNLoss, TargetNetUpdater | None]: """Builds the DQN loss module.""" loss_kwargs = {} if cfg.distributional: diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 6e74386ed63..543bf940031 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import itertools from dataclasses import dataclass @@ -60,7 +62,7 @@ def make_dqn_actor( - proof_environment: EnvBase, cfg: "DictConfig", device: torch.device # noqa: F821 + proof_environment: EnvBase, cfg: DictConfig, device: torch.device # noqa: F821 ) -> Actor: """DQN constructor helper function. @@ -194,7 +196,7 @@ def make_dqn_actor( @set_lazy_legacy(False) def make_dreamer( - cfg: "DictConfig", # noqa: F821 + cfg: DictConfig, # noqa: F821 proof_environment: EnvBase = None, device: DEVICE_TYPING = "cpu", action_key: str = "action", diff --git a/torchrl/trainers/helpers/replay_buffer.py b/torchrl/trainers/helpers/replay_buffer.py index 6ccbb15a291..d0da6a02964 100644 --- a/torchrl/trainers/helpers/replay_buffer.py +++ b/torchrl/trainers/helpers/replay_buffer.py @@ -2,8 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from dataclasses import dataclass -from typing import Optional import torch from torchrl._utils import _make_ordinal_device @@ -18,7 +19,7 @@ def make_replay_buffer( - device: DEVICE_TYPING, cfg: "DictConfig" # noqa: F821 + device: DEVICE_TYPING, cfg: DictConfig # noqa: F821 ) -> ReplayBuffer: # noqa: F821 """Builds a replay buffer using the config built from ReplayArgsConfig.""" device = _make_ordinal_device(torch.device(device)) @@ -52,7 +53,7 @@ class ReplayArgsConfig: # buffer size, in number of frames stored. Default=1e6 prb: bool = False # whether a Prioritized replay buffer should be used instead of a more basic circular one. - buffer_scratch_dir: Optional[str] = None + buffer_scratch_dir: str | None = None # directory where the buffer data should be stored. 
If None is passed, the data will be placed in /tmp/ buffer_prefetch: int = 10 # prefetching queue length for the replay buffer diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 4819d9e07e8..1f05e520968 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -2,9 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional, Union from warnings import warn import torch @@ -51,7 +51,7 @@ class TrainerConfig: # Optimizer to be used. lr_scheduler: str = "cosine" # LR scheduler. - selected_keys: Optional[List] = None + selected_keys: list | None = None # a list of strings that indicate the data that should be kept from the data collector. Since storing and # retrieving information from the replay buffer does not come for free, limiting the amount of data # passed to it can improve the algorithm performance. @@ -80,14 +80,14 @@ class TrainerConfig: def make_trainer( collector: DataCollectorBase, loss_module: LossModule, - recorder: Optional[EnvBase] = None, - target_net_updater: Optional[TargetNetUpdater] = None, - policy_exploration: Optional[ - Union[TensorDictModuleWrapper, TensorDictModule] - ] = None, - replay_buffer: Optional[ReplayBuffer] = None, - logger: Optional[Logger] = None, - cfg: "DictConfig" = None, # noqa: F821 + recorder: EnvBase | None = None, + target_net_updater: TargetNetUpdater | None = None, + policy_exploration: None | ( + TensorDictModuleWrapper | TensorDictModule + ) = None, + replay_buffer: ReplayBuffer | None = None, + logger: Logger | None = None, + cfg: DictConfig = None, # noqa: F821 ) -> Trainer: """Creates a Trainer instance given its constituents. From 6bb023d47e3e28d63f87976ffc36c2738799de26 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 10:16:38 +0000 Subject: [PATCH 07/10] Update [ghstack-poisoned] --- torchrl/envs/libs/brax.py | 1 - torchrl/envs/libs/jax_utils.py | 4 +--- torchrl/envs/model_based/dreamer.py | 1 - torchrl/envs/transforms/functional.py | 1 - torchrl/modules/distributions/discrete.py | 4 +--- torchrl/modules/distributions/utils.py | 1 - torchrl/objectives/value/utils.py | 1 - torchrl/trainers/helpers/collectors.py | 8 ++------ torchrl/trainers/helpers/trainers.py | 4 +--- 9 files changed, 5 insertions(+), 20 deletions(-) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index ea080e20dfc..8785ed7597a 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -7,7 +7,6 @@ import importlib.util import warnings - import torch from packaging import version from tensordict import TensorDict, TensorDictBase diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 6b3d7520ba7..337e9e9a1a9 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -41,9 +41,7 @@ def _tree_flatten(x, batch_size: torch.Size): } -def _ndarray_to_tensor( - value: jnp.ndarray | np.ndarray # noqa: F821 -) -> torch.Tensor: +def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa: F821 from jax import dlpack as jax_dlpack, numpy as jnp # JAX arrays generated by jax.vmap would have Numpy dtypes.
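Note on the rewrite pattern applied throughout these hunks: once "from __future__ import annotations" (PEP 563) is in place, annotations are stored as strings rather than evaluated at import time, so the PEP 604 spelling X | None is valid down to Python 3.8 and forward references such as DictConfig no longer need quoting. A minimal before/after sketch, using a hypothetical helper rather than any function from this diff:

    from __future__ import annotations  # PEP 563: annotations are kept as strings

    from typing import Callable

    # before the sweep: def pick(key: Optional[str] = None) -> Union[Callable, dict]
    # after the pyupgrade-style rewrite that the future import enables:
    def pick(key: str | None = None) -> Callable | dict:
        """Hypothetical helper showing the Optional/Union to PEP 604 rewrite."""
        return {"key": key}

The # noqa: F821 markers stay wherever the referenced name (e.g. DictConfig) is still undefined at runtime.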
diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 8c5cfe5afcc..9228c39aa66 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations - import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py index ec6fbafd154..cd6c61c0502 100644 --- a/torchrl/envs/transforms/functional.py +++ b/torchrl/envs/transforms/functional.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations - from torch import Tensor diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index f3a5f081721..8e9cda99b3c 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -128,9 +128,7 @@ def entropy(self): return -p_log_p.sum(-1) @_one_hot_wrapper(D.Categorical) - def sample( - self, sample_shape: torch.Size | Sequence | None = None - ) -> torch.Tensor: + def sample(self, sample_shape: torch.Size | Sequence | None = None) -> torch.Tensor: ... def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor: diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index d807c511fd6..a64d55276c3 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations - import torch from torch import autograd, distributions as d from torch.distributions import Independent, Transform, TransformedDistribution diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index 2abf245a8ff..4d3c4c3b552 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations - import torch from tensordict import TensorDictBase diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index e278ca1d4ae..4f13597a8e2 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -250,9 +250,7 @@ def _make_collector( def make_collector_offpolicy( make_env: Callable[[], EnvBase], - actor_model_explore: ( - TensorDictModuleWrapper | ProbabilisticTensorDictSequential - ), + actor_model_explore: (TensorDictModuleWrapper | ProbabilisticTensorDictSequential), cfg: DictConfig, # noqa: F821 make_env_kwargs: dict | None = None, ) -> DataCollectorBase: @@ -314,9 +312,7 @@ def make_collector_offpolicy( def make_collector_onpolicy( make_env: Callable[[], EnvBase], - actor_model_explore: ( - TensorDictModuleWrapper | ProbabilisticTensorDictSequential - ), + actor_model_explore: (TensorDictModuleWrapper | ProbabilisticTensorDictSequential), cfg: DictConfig, # noqa: F821 make_env_kwargs: dict | None = None, ) -> DataCollectorBase: diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 1f05e520968..4a1e35e0e4a 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -82,9 +82,7 @@ def make_trainer( loss_module: LossModule, recorder: EnvBase | None = None, target_net_updater: TargetNetUpdater | None = None, - policy_exploration: None | ( - TensorDictModuleWrapper | TensorDictModule - ) = None, + policy_exploration: None | (TensorDictModuleWrapper | TensorDictModule) = None, replay_buffer: ReplayBuffer | None = None, logger: Logger | None = None, cfg: DictConfig = None, # noqa: F821 From 851247641ce467a30461dcf2af33ea5d4424a487 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 11:43:23 +0000 Subject: [PATCH 08/10] Update [ghstack-poisoned] --- .github/workflows/lint.yml | 2 +- build_tools/setup_helpers/__init__.py | 2 +- check_future_imports.py | 33 ------- docs/source/reference/envs.rst | 2 +- docs/source/reference/objectives.rst | 1 + test/test_cost.py | 11 ++- torchrl/envs/__init__.py | 127 ++++++++++++++----------- torchrl/envs/custom/__init__.py | 4 +- torchrl/envs/transforms/__init__.py | 28 ++++-- torchrl/objectives/__init__.py | 27 +++--- torchrl/objectives/common.py | 13 ++- torchrl/objectives/value/advantages.py | 13 ++- 12 files changed, 133 insertions(+), 130 deletions(-) delete mode 100644 check_future_imports.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 7b710a751fe..25ac1c49e43 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -35,7 +35,7 @@ jobs: echo '::endgroup::' echo '::group::Install lint tools' - pip install --progress-bar=off pre-commit + pip install --progress-bar=off pre-commit autoflake echo '::endgroup::' echo '::group::Lint Python source and configs' diff --git a/build_tools/setup_helpers/__init__.py b/build_tools/setup_helpers/__init__.py index e7cbc4a0163..52c1db79251 100644 --- a/build_tools/setup_helpers/__init__.py +++ b/build_tools/setup_helpers/__init__.py @@ -3,6 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .extension import CMakeBuild, get_ext_modules # noqa +from .extension import CMakeBuild, get_ext_modules __all__ = ["CMakeBuild", "get_ext_modules"] diff --git a/check_future_imports.py b/check_future_imports.py deleted file mode 100644 index 796002993eb..00000000000 --- a/check_future_imports.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import os - - -def find_files_without_future_annotations(directory): - """Finds Python files that do not contain 'from __future__ import annotations'.""" - files_without_annotations = [] - - for root, _, files in os.walk(directory): - for file in files: - if file.endswith(".py"): - file_path = os.path.join(root, file) - with open(file_path, encoding="utf-8") as f: - content = f.read() - if "from __future__ import annotations" not in content: - files_without_annotations.append(file_path) - - return files_without_annotations - - -if __name__ == "__main__": - repo_directory = "." # Change this to your repository's root directory - files = find_files_without_future_annotations(repo_directory) - if files: - print("Files without 'from __future__ import annotations':") - for file in files: - print(file) - else: - print("All files contain 'from __future__ import annotations'.") diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index b53ac84585d..71bc2b2219f 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -1220,7 +1220,7 @@ Recorders are transforms that register data as they come in, for logging purpose Helpers ------- -.. currentmodule:: torchrl.envs.utils +.. currentmodule:: torchrl.envs .. autosummary:: :toctree: generated/ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 3d88536e3d9..f2741809bd3 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -111,6 +111,7 @@ auto-completion to make their choice. :template: rl_template_noinherit.rst LossModule + add_random_module DQN --- diff --git a/test/test_cost.py b/test/test_cost.py index eda6596e8e4..3fd1fad62da 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -100,7 +100,7 @@ TD3BCLoss, TD3Loss, ) -from torchrl.objectives.common import LossModule +from torchrl.objectives.common import add_random_module, LossModule from torchrl.objectives.deprecated import DoubleREDQLoss_deprecated, REDQLoss_deprecated from torchrl.objectives.redq import REDQLoss from torchrl.objectives.reinforce import ReinforceLoss @@ -16162,6 +16162,15 @@ def _composite_log_prob(self): yield setter.unset() + def test_add_random_module(self): + class MyMod(nn.Module): + ... 
+ + add_random_module(MyMod) + import torchrl.objectives.utils + + assert MyMod in torchrl.objectives.utils.RANDOM_MODULE_LIST + def test_standardization(self): t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6) std_t0 = _standardize(t, exclude_dims=(1, 3)) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index e1edc291c00..19d0cdbae41 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,7 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv +from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( @@ -46,6 +46,8 @@ from .transforms import ( ActionDiscretizer, ActionMask, + as_nested_tensor, + as_padded_tensor, AutoResetEnv, AutoResetTransform, BatchSizeTransform, @@ -58,6 +60,7 @@ Compose, ConditionalSkip, Crop, + DataLoadingPrimer, DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, @@ -116,100 +119,91 @@ check_marl_grouping, exploration_type, ExplorationType, + get_available_libraries, make_composite_from_td, MarlGroupMapType, + RandomPolicy, set_exploration_type, step_mdp, + terminated_or_truncated, ) __all__ = [ - "ParallelEnv", - "SerialEnv", - "EnvBase", - "EnvMetaData", - "make_tensordict", - "ChessEnv", - "LLMHashingEnv", - "PendulumEnv", - "TicTacToeEnv", - "env_creator", - "EnvCreator", - "get_env_metadata", - "default_info_dict_reader", - "GymLikeEnv", - "BraxEnv", - "BraxWrapper", - "DMControlEnv", - "DMControlWrapper", - "gym_backend", - "GymEnv", - "GymWrapper", - "HabitatEnv", - "IsaacGymEnv", - "IsaacGymWrapper", - "JumanjiEnv", - "JumanjiWrapper", - "MeltingpotEnv", - "MeltingpotWrapper", - "MOGymEnv", - "MOGymWrapper", - "MultiThreadedEnv", - "MultiThreadedEnvWrapper", - "OpenMLEnv", - "OpenSpielEnv", - "OpenSpielWrapper", - "PettingZooEnv", - "PettingZooWrapper", - "register_gym_spec_conversion", - "RoboHiveEnv", - "set_gym_backend", - "SMACv2Env", - "SMACv2Wrapper", - "UnityMLAgentsEnv", - "UnityMLAgentsWrapper", - "VmasEnv", - "VmasWrapper", - "DreamerDecoder", - "DreamerEnv", - "ModelBasedEnvBase", "ActionDiscretizer", "ActionMask", "AutoResetEnv", "AutoResetTransform", "BatchSizeTransform", "BinarizeReward", + "BraxEnv", + "BraxWrapper", "BurnInTransform", "CatFrames", "CatTensors", "CenterCrop", + "ChessEnv", "ClipTransform", "Compose", "ConditionalSkip", "Crop", + "DMControlEnv", + "DMControlWrapper", + "DTypeCastTransform", + "DataLoadingPrimer", "DeviceCastTransform", "DiscreteActionProjection", "DoubleToFloat", - "DTypeCastTransform", + "DreamerDecoder", + "DreamerEnv", "EndOfLifeTransform", + "EnvBase", + "EnvCreator", + "EnvMetaData", "ExcludeTransform", + "ExplorationType", "FiniteTensorDictCheck", "FlattenObservation", "FrameSkipTransform", "GrayScale", - "gSDENoise", + "GymEnv", + "GymLikeEnv", + "GymWrapper", + "HabitatEnv", "Hash", "InitTracker", + "IsaacGymEnv", + "IsaacGymWrapper", + "JumanjiEnv", + "JumanjiWrapper", "KLRewardTransform", + "LLMEnv", + "LLMHashingEnv", "LineariseRewards", + "MOGymEnv", + "MOGymWrapper", + "MarlGroupMapType", + "MeltingpotEnv", + "MeltingpotWrapper", + "ModelBasedEnvBase", "MultiAction", "MultiStepTransform", + "MultiThreadedEnv", + "MultiThreadedEnvWrapper", "NoopResetEnv", "ObservationNorm", "ObservationTransform", + "OpenMLEnv", + "OpenSpielEnv", + 
"OpenSpielWrapper", + "ParallelEnv", + "PendulumEnv", "PermuteTransform", + "PettingZooEnv", + "PettingZooWrapper", "PinMemoryTransform", "R3MTransform", "RandomCropTensorDict", + "RandomPolicy", "RemoveEmptySpecs", "RenameTransform", "Resize", @@ -217,33 +211,52 @@ "RewardClipping", "RewardScaling", "RewardSum", + "RoboHiveEnv", + "SMACv2Env", + "SMACv2Wrapper", "SelectTransform", + "SerialEnv", "SignTransform", "SqueezeTransform", "Stack", "StepCounter", "TargetReturn", "TensorDictPrimer", + "TicTacToeEnv", "TimeMaxPool", "Timer", - "Tokenizer", "ToTensorImage", + "Tokenizer", "TrajCounter", "Transform", "TransformedEnv", "UnaryTransform", + "UnityMLAgentsEnv", + "UnityMLAgentsWrapper", "UnsqueezeTransform", "VC1Transform", - "VecGymEnvTransform", - "VecNorm", "VIPRewardTransform", "VIPTransform", + "VecGymEnvTransform", + "VecNorm", + "VmasEnv", + "VmasWrapper", + "as_nested_tensor", + "as_padded_tensor", "check_env_specs", "check_marl_grouping", + "default_info_dict_reader", + "env_creator", "exploration_type", - "ExplorationType", + "gSDENoise", + "get_available_libraries", + "get_env_metadata", + "gym_backend", "make_composite_from_td", - "MarlGroupMapType", + "make_tensordict", + "register_gym_spec_conversion", "set_exploration_type", + "set_gym_backend", "step_mdp", + "terminated_or_truncated", ] diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index 9c98af1644a..24ffee4b3f1 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -4,8 +4,8 @@ # LICENSE file in the root directory of this source tree. from .chess import ChessEnv -from .llm import LLMHashingEnv +from .llm import LLMEnv, LLMHashingEnv from .pendulum import PendulumEnv from .tictactoeenv import TicTacToeEnv -__all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv"] +__all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv", "LLMEnv"] diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 3ba694a3fce..c6814f8745d 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -6,7 +6,12 @@ from .gym_transforms import EndOfLifeTransform from .r3m import R3MTransform from .rb_transforms import MultiStepTransform -from .rlhf import KLRewardTransform +from .rlhf import ( + as_nested_tensor, + as_padded_tensor, + DataLoadingPrimer, + KLRewardTransform, +) from .transforms import ( ActionDiscretizer, ActionMask, @@ -72,10 +77,6 @@ from .vip import VIPRewardTransform, VIPTransform __all__ = [ - "EndOfLifeTransform", - "R3MTransform", - "MultiStepTransform", - "KLRewardTransform", "ActionDiscretizer", "ActionMask", "AutoResetEnv", @@ -90,25 +91,29 @@ "Compose", "ConditionalSkip", "Crop", + "DTypeCastTransform", + "DataLoadingPrimer", "DeviceCastTransform", "DiscreteActionProjection", "DoubleToFloat", - "DTypeCastTransform", + "EndOfLifeTransform", "ExcludeTransform", "FiniteTensorDictCheck", "FlattenObservation", "FrameSkipTransform", "GrayScale", - "gSDENoise", "Hash", "InitTracker", + "KLRewardTransform", "LineariseRewards", "MultiAction", + "MultiStepTransform", "NoopResetEnv", "ObservationNorm", "ObservationTransform", "PermuteTransform", "PinMemoryTransform", + "R3MTransform", "RandomCropTensorDict", "RemoveEmptySpecs", "RenameTransform", @@ -126,16 +131,19 @@ "TensorDictPrimer", "TimeMaxPool", "Timer", - "Tokenizer", "ToTensorImage", + "Tokenizer", "TrajCounter", "Transform", "TransformedEnv", "UnaryTransform", "UnsqueezeTransform", - "VecGymEnvTransform", - "VecNorm", 
"VC1Transform", "VIPRewardTransform", "VIPTransform", + "VecGymEnvTransform", + "VecNorm", + "as_nested_tensor", + "as_padded_tensor", + "gSDENoise", ] diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 31ee96c5bfb..da198d30c26 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -35,39 +35,40 @@ __all__ = [ "A2CLoss", - "LossModule", "CQLLoss", - "DiscreteCQLLoss", + "ClipPPOLoss", "CrossQLoss", "DDPGLoss", + "DQNLoss", "DTLoss", - "OnlineDTLoss", + "DiscreteCQLLoss", + "DiscreteIQLLoss", + "DiscreteSACLoss", "DistributionalDQNLoss", - "DQNLoss", "DreamerActorLoss", "DreamerModelLoss", "DreamerValueLoss", "GAILLoss", - "DiscreteIQLLoss", + "HardUpdate", "IQLLoss", - "QMixerLoss", - "ClipPPOLoss", "KLPENPPOLoss", + "LossModule", + "OnlineDTLoss", "PPOLoss", + "QMixerLoss", "REDQLoss", "ReinforceLoss", - "DiscreteSACLoss", "SACLoss", - "TD3Loss", + "SoftUpdate", "TD3BCLoss", + "TD3Loss", + "TargetNetUpdater", + "ValueEstimators", + "add_random_module", "default_value_kwargs", "distance_loss", "group_optimizers", - "HardUpdate", "hold_out_net", "hold_out_params", "next_state_value", - "SoftUpdate", - "TargetNetUpdater", - "ValueEstimators", ] diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index f74bcc3b36f..536e95265c7 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -633,6 +633,8 @@ def vmap_randomness(self): """ if self._vmap_randomness is None: + import torchrl.objectives.utils + main_modules = list(self.__dict__.values()) + list(self.children()) modules = ( module @@ -641,7 +643,7 @@ def vmap_randomness(self): for module in main_module.modules() ) for val in modules: - if isinstance(val, RANDOM_MODULE_LIST): + if isinstance(val, torchrl.objectives.utils.RANDOM_MODULE_LIST): self._vmap_randomness = "different" break else: @@ -685,7 +687,10 @@ def __call__(self, x): return x -def add_ramdom_module(module): +def add_random_module(module): """Adds a random module to the list of modules that will be detected by :meth:`~torchrl.objectives.LossModule.vmap_randomness` as random.""" - global RANDOM_MODULE_LIST - RANDOM_MODULE_LIST = RANDOM_MODULE_LIST + (module,) + import torchrl.objectives.utils + + torchrl.objectives.utils.RANDOM_MODULE_LIST = ( + torchrl.objectives.utils.RANDOM_MODULE_LIST + (module,) + ) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 4f12809809e..8194b72ee3e 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -29,12 +29,7 @@ from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp -from torchrl.objectives.utils import ( - _maybe_get_or_select, - _vmap_func, - hold_out_net, - RANDOM_MODULE_LIST, -) +from torchrl.objectives.utils import _maybe_get_or_select, _vmap_func, hold_out_net from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -388,8 +383,12 @@ def vmap_randomness(self): do_break = False for val in self.__dict__.values(): if isinstance(val, torch.nn.Module): + import torchrl.objectives.utils + for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): + if isinstance( + module, torchrl.objectives.utils.RANDOM_MODULE_LIST + ): self._vmap_randomness = "different" do_break = True break From b1a9c4455cc2211dddb67e578fafa42378fceb46 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 13:26:20 +0000 Subject: [PATCH 09/10] Update [ghstack-poisoned] --- 
setup.cfg | 4 ++++ test/test_helpers.py | 2 ++ torchrl/objectives/common.py | 2 +- torchrl/trainers/helpers/envs.py | 34 ++++++++++++++++++-------------- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/setup.cfg b/setup.cfg index 985c68e5af9..0649a97497f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,3 +45,7 @@ ignore-decorators = test_* ; test/*.py ; .circleci/* + +[autoflake] +per-file-ignores = + torchrl/trainers/helpers/envs.py * diff --git a/test/test_helpers.py b/test/test_helpers.py index cf1160f1bb2..3ba2326254d 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import argparse import dataclasses @@ -225,6 +226,7 @@ def test_timeit(): @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") @pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")]) def test_transformed_env_constructor_with_state_dict(from_pixels): + config_fields = [ (config_field.name, config_field.type, config_field) for config_cls in ( diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 536e95265c7..50fc7ee7fba 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -21,7 +21,7 @@ from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import set_recurrent_mode +from torchrl.modules.tensordict_module.rnn import set_recurrent_mode from torchrl.objectives.utils import ValueEstimators from torchrl.objectives.value import ValueEstimatorBase diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 170a82f3414..af74877ed52 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -2,13 +2,17 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from __future__ import annotations + +# This makes omegaconf unhappy with typing.Any +# Therefore we need Optional and Union +# from __future__ import annotations from copy import copy from dataclasses import dataclass, field as dataclass_field -from typing import Any, Callable, Sequence +from typing import Any, Callable, Optional, Sequence, Union import torch +from omegaconf import DictConfig from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.envs import ParallelEnv @@ -212,18 +216,18 @@ def get_norm_state_dict(env): def transformed_env_constructor( cfg: DictConfig, # noqa: F821 video_tag: str = "", - logger: Logger | None = None, - stats: dict | None = None, + logger: Optional[Logger] = None, # noqa + stats: Optional[dict] = None, norm_obs_only: bool = False, use_env_creator: bool = False, - custom_env_maker: Callable | None = None, - custom_env: EnvBase | None = None, + custom_env_maker: Optional[Callable] = None, + custom_env: Optional[EnvBase] = None, return_transformed_envs: bool = True, - action_dim_gsde: int | None = None, - state_dim_gsde: int | None = None, - batch_dims: int | None = 0, - obs_norm_state_dict: dict | None = None, -) -> Callable | EnvCreator: + action_dim_gsde: Optional[int] = None, + state_dim_gsde: Optional[int] = None, + batch_dims: Optional[int] = 0, + obs_norm_state_dict: Optional[dict] = None, +) -> Union[Callable, EnvCreator]: """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. 
Args: @@ -329,7 +333,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv: def parallel_env_constructor( cfg: DictConfig, **kwargs # noqa: F821 -) -> ParallelEnv | EnvCreator: +) -> Union[ParallelEnv, EnvCreator]: """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor. Args: @@ -374,7 +378,7 @@ def parallel_env_constructor( def get_stats_random_rollout( cfg: DictConfig, # noqa: F821 proof_environment: EnvBase = None, - key: str | None = None, + key: Optional[str] = None, ): """Gathers stats (loc and scale) from an environment using random rollouts. @@ -452,7 +456,7 @@ def get_stats_random_rollout( def initialize_observation_norm_transforms( proof_environment: EnvBase, num_iter: int = 1000, - key: str | tuple[str, ...] = None, + key: Union[str, tuple[str, ...]] = None, ): """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`. @@ -532,7 +536,7 @@ class EnvConfig: # maximum steps per trajectory, frames per batch or any other factor in the algorithm, # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4 # the actual number of frames retrieved will be 200e6. Default=1. - reward_scaling: float | None = None + reward_scaling: Any = None # noqa # scale of the reward. reward_loc: float = 0.0 # location of the reward. From 688d12b286131fe2434366ae02bc2662101c10c1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 15:02:59 +0000 Subject: [PATCH 10/10] Update [ghstack-poisoned] --- examples/rlhf/models/actor_critic.py | 2 + torchrl/envs/transforms/gym_transforms.py | 5 ++- torchrl/envs/transforms/transforms.py | 3 +- torchrl/modules/tensordict_module/__init__.py | 14 +++---- torchrl/modules/tensordict_module/rnn.py | 19 ++++++--- torchrl/objectives/__init__.py | 40 ++++++++++--------- torchrl/objectives/cql.py | 2 +- torchrl/trainers/helpers/envs.py | 11 ++++- 8 files changed, 60 insertions(+), 36 deletions(-) diff --git a/examples/rlhf/models/actor_critic.py b/examples/rlhf/models/actor_critic.py index 3de34d55166..b5be188fbd9 100644 --- a/examples/rlhf/models/actor_critic.py +++ b/examples/rlhf/models/actor_critic.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator from torchrl.modules.tensordict_module.common import VmapModule diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py index c908207139c..487329942c6 100644 --- a/torchrl/envs/transforms/gym_transforms.py +++ b/torchrl/envs/transforms/gym_transforms.py @@ -10,7 +10,6 @@ import warnings import torch -import torchrl.objectives.common from tensordict import TensorDictBase from tensordict.utils import expand_as_right, NestedKey from torchrl.data.tensor_specs import Unbounded @@ -189,7 +188,9 @@ def transform_observation_spec(self, observation_spec): ) return observation_spec - def register_keys(self, loss_or_advantage: torchrl.objectives.common.LossModule): + def register_keys( + self, loss_or_advantage: torchrl.objectives.common.LossModule # noqa + ): """Registers the end-of-life key at appropriate places within the loss.
Args: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 4960e880f67..eff19ef1b61 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -87,7 +87,6 @@ make_composite_from_td, step_mdp, ) -from torchrl.objectives.value.functional import reward2go _has_tv = importlib.util.find_spec("torchvision", None) is not None @@ -8539,6 +8538,8 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: def _inv_apply_transform( self, reward: torch.Tensor, done: torch.Tensor ) -> torch.Tensor: + from torchrl.objectives.value.functional import reward2go + return reward2go(reward, done, self.gamma) def set_container(self, container): diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index fadf4035ef5..add36202bba 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .actors import ( +from torchrl.modules.tensordict_module.actors import ( Actor, ActorCriticOperator, ActorCriticWrapper, @@ -21,8 +21,8 @@ TanhModule, ValueOperator, ) -from .common import SafeModule, VmapModule -from .exploration import ( +from torchrl.modules.tensordict_module.common import SafeModule, VmapModule +from torchrl.modules.tensordict_module.exploration import ( AdditiveGaussianModule, AdditiveGaussianWrapper, EGreedyModule, @@ -30,11 +30,11 @@ OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, ) -from .probabilistic import ( +from torchrl.modules.tensordict_module.probabilistic import ( SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import ( +from torchrl.modules.tensordict_module.rnn import ( GRU, GRUCell, GRUModule, @@ -44,8 +44,8 @@ recurrent_mode, set_recurrent_mode, ) -from .sequence import SafeSequential -from .world_models import WorldModelWrapper +from torchrl.modules.tensordict_module.sequence import SafeSequential +from torchrl.modules.tensordict_module.world_models import WorldModelWrapper __all__ = [ "Actor", diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index cf6b7ad5533..598c8026578 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -19,11 +19,6 @@ from torchrl._utils import _ContextManager, _DecoratorContextManager from torchrl.data.tensor_specs import Unbounded -from torchrl.objectives.value.functional import ( - _inv_pad_sequence, - _split_and_pad_sequence, -) -from torchrl.objectives.value.utils import _get_num_per_traj_init class LSTMCell(RNNCellBase): @@ -718,6 +713,11 @@ def set_recurrent_mode(self, mode: bool = True): @dispatch def forward(self, tensordict: TensorDictBase): + from torchrl.objectives.value.functional import ( + _inv_pad_sequence, + _split_and_pad_sequence, + ) + # we want to get an error if the value input is missing, but not the hidden states defaults = [NO_DEFAULT, None, None] shape = tensordict.shape @@ -742,6 +742,8 @@ def forward(self, tensordict: TensorDictBase): is_init = tensordict_shaped["is_init"].squeeze(-1) splits = None if self.recurrent_mode and is_init[..., 1:].any(): + from torchrl.objectives.value.utils import _get_num_per_traj_init + # if we have consecutive trajectories, things get a little more complicated # we have a tensordict of shape [B, T] # we will split / pad things 
such that we get a tensordict of shape @@ -1533,6 +1535,11 @@ def set_recurrent_mode(self, mode: bool = True): @dispatch @set_lazy_legacy(False) def forward(self, tensordict: TensorDictBase): + from torchrl.objectives.value.functional import ( + _inv_pad_sequence, + _split_and_pad_sequence, + ) + # we want to get an error if the value input is missing, but not the hidden states defaults = [NO_DEFAULT, None] shape = tensordict.shape @@ -1557,6 +1564,8 @@ def forward(self, tensordict: TensorDictBase): is_init = tensordict_shaped["is_init"].squeeze(-1) splits = None if self.recurrent_mode and is_init[..., 1:].any(): + from torchrl.objectives.value.utils import _get_num_per_traj_init + # if we have consecutive trajectories, things get a little more complicated # we have a tensordict of shape [B, T] # we will split / pad things such that we get a tensordict of shape diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index da198d30c26..fd7ac06048b 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -3,24 +3,28 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .a2c import A2CLoss -from .common import LossModule -from .cql import CQLLoss, DiscreteCQLLoss -from .crossq import CrossQLoss -from .ddpg import DDPGLoss -from .decision_transformer import DTLoss, OnlineDTLoss -from .dqn import DistributionalDQNLoss, DQNLoss -from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss -from .gail import GAILLoss -from .iql import DiscreteIQLLoss, IQLLoss -from .multiagent import QMixerLoss -from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss -from .redq import REDQLoss -from .reinforce import ReinforceLoss -from .sac import DiscreteSACLoss, SACLoss -from .td3 import TD3Loss -from .td3_bc import TD3BCLoss -from .utils import ( +from torchrl.objectives.a2c import A2CLoss +from torchrl.objectives.common import LossModule +from torchrl.objectives.cql import CQLLoss, DiscreteCQLLoss +from torchrl.objectives.crossq import CrossQLoss +from torchrl.objectives.ddpg import DDPGLoss +from torchrl.objectives.decision_transformer import DTLoss, OnlineDTLoss +from torchrl.objectives.dqn import DistributionalDQNLoss, DQNLoss +from torchrl.objectives.dreamer import ( + DreamerActorLoss, + DreamerModelLoss, + DreamerValueLoss, +) +from torchrl.objectives.gail import GAILLoss +from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss +from torchrl.objectives.multiagent import QMixerLoss +from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss +from torchrl.objectives.redq import REDQLoss +from torchrl.objectives.reinforce import ReinforceLoss +from torchrl.objectives.sac import DiscreteSACLoss, SACLoss +from torchrl.objectives.td3 import TD3Loss +from torchrl.objectives.td3_bc import TD3BCLoss +from torchrl.objectives.utils import ( default_value_kwargs, distance_loss, group_optimizers, diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 3ca291c09ae..d94728985a2 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -20,7 +20,7 @@ from torchrl.data.tensor_specs import Composite from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import ProbabilisticActor, QValueActor +from torchrl.modules.tensordict_module.actors import ProbabilisticActor, QValueActor from torchrl.modules.tensordict_module.common import 
ensure_tensordict_compatible from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index af74877ed52..32965742d66 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -7,13 +7,12 @@ # Therefore we need Optional and Union # from __future__ import annotations +import importlib.util from copy import copy from dataclasses import dataclass, field as dataclass_field from typing import Any, Callable, Optional, Sequence, Union import torch -from omegaconf import DictConfig - from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.envs import ParallelEnv from torchrl.envs.common import EnvBase @@ -49,6 +48,14 @@ "dm_control": DMControlEnv, } +_has_omegaconf = importlib.util.find_spec("omegaconf") is not None +if _has_omegaconf: + from omegaconf import DictConfig +else: + + class DictConfig: # noqa + ... + def correct_for_frame_skip(cfg: DictConfig) -> DictConfig: # noqa: F821 """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.
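The thread running through this final patch is breaking import cycles between torchrl.envs/torchrl.modules and torchrl.objectives: module-level imports of objectives code are either made absolute or moved into the function bodies that use them, so each package finishes initializing before the other is touched. In miniature, assuming only that reward2go(reward, done, gamma) keeps the signature used in the transforms.py hunk above (the wrapper itself is hypothetical):

    import torch


    def discounted_return(
        reward: torch.Tensor, done: torch.Tensor, gamma: float = 0.99
    ) -> torch.Tensor:
        # deferred import: resolved on first call, once torchrl.objectives is
        # fully initialized, so importing torchrl.envs alone stays cycle-free
        from torchrl.objectives.value.functional import reward2go

        return reward2go(reward, done, gamma)

The same reasoning applies to the RANDOM_MODULE_LIST lookups introduced in PATCH 08: reading the tuple through torchrl.objectives.utils at call time picks up modules registered later by add_random_module, whereas an early from-import would capture a stale binding.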