Skip to content

Commit 40b147e

Browse files
committed
[Lint] pyupgrade
ghstack-source-id: dcdf51db31b8f6bcfad7fd4dc53f5b5ad8098c5d Pull Request resolved: #2819
1 parent 433d0e6 commit 40b147e

File tree

217 files changed

+2011
-1440
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

217 files changed

+2011
-1440
lines changed

.github/unittest/helpers/coverage_run_parallel.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def write_config(config_path: Path, argv: List[str]) -> None:
2828
argv: Arguments passed to this script, which need to be converted to config file entries
2929
"""
3030
assert not config_path.exists(), "Temporary coverage config exists already"
31-
cmdline = " ".join(shlex.quote(arg) for arg in argv[1:])
32-
with open(str(config_path), "wt", encoding="utf-8") as fh:
31+
cmdline = shlex.join(argv[1:])
32+
with open(str(config_path), "w", encoding="utf-8") as fh:
3333
fh.write(
3434
f"""# .coveragerc to control coverage.py
3535
[run]

.github/workflows/lint.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
echo '::endgroup::'
3636
3737
echo '::group::Install lint tools'
38-
pip install --progress-bar=off pre-commit
38+
pip install --progress-bar=off pre-commit autoflake
3939
echo '::endgroup::'
4040
4141
echo '::group::Lint Python source and configs'

.pre-commit-config.yaml

+14
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,17 @@ repos:
3535
hooks:
3636
- id: pydocstyle
3737
files: ^torchrl/
38+
39+
- repo: https://github.com/asottile/pyupgrade
40+
rev: v3.9.0
41+
hooks:
42+
- id: pyupgrade
43+
args: [--py38-plus]
44+
45+
- repo: local
46+
hooks:
47+
- id: autoflake
48+
name: autoflake
49+
entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports
50+
language: system
51+
types: [python]

build_tools/setup_helpers/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .extension import CMakeBuild, get_ext_modules # noqa
6+
from .extension import CMakeBuild, get_ext_modules
7+
8+
__all__ = ["CMakeBuild", "get_ext_modules"]

build_tools/setup_helpers/extension.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from setuptools import Extension
1515
from setuptools.command.build_ext import build_ext
1616

17-
1817
_THIS_DIR = Path(__file__).parent.resolve()
1918
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
2019
_TORCHRL_DIR = _ROOT_DIR / "torchrl"
@@ -130,7 +129,7 @@ def build_extension(self, ext):
130129
# using -j in the build_ext call, not supported by pip or PyPA-build.
131130
if hasattr(self, "parallel") and self.parallel:
132131
# CMake 3.12+ only.
133-
build_args += ["-j{}".format(self.parallel)]
132+
build_args += [f"-j{self.parallel}"]
134133

135134
if not os.path.exists(self.build_temp):
136135
os.makedirs(self.build_temp)

docs/source/reference/envs.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ Recorders are transforms that register data as they come in, for logging purpose
12201220

12211221
Helpers
12221222
-------
1223-
.. currentmodule:: torchrl.envs.utils
1223+
.. currentmodule:: torchrl.envs
12241224

12251225
.. autosummary::
12261226
:toctree: generated/

docs/source/reference/objectives.rst

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ auto-completion to make their choice.
111111
:template: rl_template_noinherit.rst
112112

113113
LossModule
114+
add_random_module
114115

115116
DQN
116117
---

examples/rlhf/models/actor_critic.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
57
from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator
68
from torchrl.modules.tensordict_module.common import VmapModule
79

setup.cfg

+4
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ ignore-decorators =
4545
test_*
4646
; test/*.py
4747
; .circleci/*
48+
49+
[autoflake]
50+
per-file-ignores =
51+
torchrl/trainers/helpers/envs.py *

setup.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
def get_version():
3232
version_txt = os.path.join(cwd, "version.txt")
33-
with open(version_txt, "r") as f:
33+
with open(version_txt) as f:
3434
version = f.readline().strip()
3535
if os.getenv("TORCHRL_BUILD_VERSION"):
3636
version = os.getenv("TORCHRL_BUILD_VERSION")
@@ -64,8 +64,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
6464
def write_version_file(version):
6565
version_path = os.path.join(cwd, "torchrl", "version.py")
6666
with open(version_path, "w") as f:
67-
f.write("__version__ = '{}'\n".format(version))
68-
f.write("git_version = {}\n".format(repr(sha)))
67+
f.write(f"__version__ = '{version}'\n")
68+
f.write(f"git_version = {repr(sha)}\n")
6969

7070

7171
def _get_pytorch_version(is_nightly, is_local):
@@ -185,7 +185,7 @@ def _main(argv):
185185
version = get_version()
186186
write_version_file(version)
187187
TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION")
188-
logging.info("Building wheel {}-{}".format(package_name, version))
188+
logging.info(f"Building wheel {package_name}-{version}")
189189
logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}")
190190

191191
is_local = TORCHRL_BUILD_VERSION is None

sota-implementations/a2c/a2c_atari.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
16-
def main(cfg: "DictConfig"): # noqa: F821
16+
def main(cfg: DictConfig): # noqa: F821
1717

1818
from copy import deepcopy
1919

sota-implementations/a2c/a2c_mujoco.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
16-
def main(cfg: "DictConfig"): # noqa: F821
16+
def main(cfg: DictConfig): # noqa: F821
1717

1818
from copy import deepcopy
1919

sota-implementations/cql/cql_offline.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515

1616
import hydra
1717
import numpy as np
18-
1918
import torch
2019
import tqdm
2120
from tensordict.nn import CudaGraphModule
22-
2321
from torchrl._utils import timeit
2422
from torchrl.envs.utils import ExplorationType, set_exploration_type
2523
from torchrl.objectives import group_optimizers
2624
from torchrl.record.loggers import generate_exp_name, get_logger
27-
2825
from utils import (
2926
dump_video,
3027
log_metrics,
@@ -39,7 +36,7 @@
3936

4037

4138
@hydra.main(config_path="", config_name="offline_config", version_base="1.1")
42-
def main(cfg: "DictConfig"): # noqa: F821
39+
def main(cfg: DictConfig): # noqa: F821
4340
# Create logger
4441
exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
4542
logger = None

sota-implementations/cql/cql_online.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@
2121
import tqdm
2222
from tensordict import TensorDict
2323
from tensordict.nn import CudaGraphModule
24-
2524
from torchrl._utils import timeit
2625
from torchrl.envs.utils import ExplorationType, set_exploration_type
2726
from torchrl.objectives import group_optimizers
2827
from torchrl.record.loggers import generate_exp_name, get_logger
29-
3028
from utils import (
3129
dump_video,
3230
log_metrics,
@@ -42,7 +40,7 @@
4240

4341

4442
@hydra.main(version_base="1.1", config_path="", config_name="online_config")
45-
def main(cfg: "DictConfig"): # noqa: F821
43+
def main(cfg: DictConfig): # noqa: F821
4644
# Create logger
4745
exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
4846
logger = None

sota-implementations/cql/discrete_cql_online.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616

1717
import hydra
1818
import numpy as np
19-
2019
import torch
2120
import torch.cuda
2221
import tqdm
2322
from tensordict.nn import CudaGraphModule
24-
2523
from torchrl._utils import timeit
26-
2724
from torchrl.envs.utils import ExplorationType, set_exploration_type
28-
2925
from torchrl.record.loggers import generate_exp_name, get_logger
3026
from utils import (
3127
log_metrics,
@@ -41,7 +37,7 @@
4137

4238

4339
@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
44-
def main(cfg: "DictConfig"): # noqa: F821
40+
def main(cfg: DictConfig): # noqa: F821
4541
device = cfg.optim.device
4642
if device in ("", None):
4743
if torch.cuda.is_available():

sota-implementations/crossq/crossq.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,15 @@
1515
import warnings
1616

1717
import hydra
18-
1918
import numpy as np
20-
2119
import torch
2220
import torch.cuda
2321
import tqdm
2422
from tensordict import TensorDict
2523
from tensordict.nn import CudaGraphModule
26-
2724
from torchrl._utils import timeit
2825
from torchrl.envs.utils import ExplorationType, set_exploration_type
2926
from torchrl.objectives import group_optimizers
30-
3127
from torchrl.record.loggers import generate_exp_name, get_logger
3228
from utils import (
3329
log_metrics,
@@ -43,7 +39,7 @@
4339

4440

4541
@hydra.main(version_base="1.1", config_path=".", config_name="config")
46-
def main(cfg: "DictConfig"): # noqa: F821
42+
def main(cfg: DictConfig): # noqa: F821
4743
device = cfg.network.device
4844
if device in ("", None):
4945
if torch.cuda.is_available():

sota-implementations/ddpg/ddpg.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515
import warnings
1616

1717
import hydra
18-
1918
import numpy as np
2019
import torch
2120
import torch.cuda
2221
import tqdm
2322
from tensordict import TensorDict
2423
from tensordict.nn import CudaGraphModule
25-
2624
from torchrl._utils import timeit
27-
2825
from torchrl.envs.utils import ExplorationType, set_exploration_type
2926
from torchrl.objectives import group_optimizers
3027
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -41,7 +38,7 @@
4138

4239

4340
@hydra.main(version_base="1.1", config_path="", config_name="config")
44-
def main(cfg: "DictConfig"): # noqa: F821
41+
def main(cfg: DictConfig): # noqa: F821
4542
device = cfg.optim.device
4643
if device in ("", None):
4744
if torch.cuda.is_available():

sota-implementations/decision_transformer/dt.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
from tensordict.nn import CudaGraphModule
2020
from torchrl._utils import logger as torchrl_logger, timeit
2121
from torchrl.envs.libs.gym import set_gym_backend
22-
2322
from torchrl.envs.utils import ExplorationType, set_exploration_type
2423
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
2524
from torchrl.record import VideoRecorder
26-
2725
from utils import (
2826
dump_video,
2927
log_metrics,
@@ -37,7 +35,7 @@
3735

3836

3937
@hydra.main(config_path="", config_name="dt_config", version_base="1.1")
40-
def main(cfg: "DictConfig"): # noqa: F821
38+
def main(cfg: DictConfig): # noqa: F821
4139
set_gym_backend(cfg.env.backend).set()
4240

4341
model_device = cfg.optim.device

sota-implementations/decision_transformer/online_dt.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from torchrl.envs.utils import ExplorationType, set_exploration_type
2121
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
2222
from torchrl.record import VideoRecorder
23-
2423
from utils import (
2524
dump_video,
2625
log_metrics,
@@ -34,7 +33,7 @@
3433

3534

3635
@hydra.main(config_path="", config_name="odt_config", version_base="1.1")
37-
def main(cfg: "DictConfig"): # noqa: F821
36+
def main(cfg: DictConfig): # noqa: F821
3837
set_gym_backend(cfg.env.backend).set()
3938

4039
model_device = cfg.optim.device

sota-implementations/discrete_sac/discrete_sac.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939

4040
@hydra.main(version_base="1.1", config_path="", config_name="config")
41-
def main(cfg: "DictConfig"): # noqa: F821
41+
def main(cfg: DictConfig): # noqa: F821
4242
device = cfg.network.device
4343
if device in ("", None):
4444
if torch.cuda.is_available():

sota-implementations/dqn/dqn_atari.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import tqdm
1919
from tensordict.nn import CudaGraphModule, TensorDictSequential
2020
from torchrl._utils import timeit
21-
2221
from torchrl.collectors import SyncDataCollector
2322
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
2423
from torchrl.envs import ExplorationType, set_exploration_type
@@ -32,7 +31,7 @@
3231

3332

3433
@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
35-
def main(cfg: "DictConfig"): # noqa: F821
34+
def main(cfg: DictConfig): # noqa: F821
3635

3736
device = cfg.device
3837
if device in ("", None):

sota-implementations/dqn/dqn_cartpole.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import torch.nn
1212
import torch.optim
1313
import tqdm
14-
1514
from tensordict.nn import CudaGraphModule, TensorDictSequential
1615
from torchrl._utils import timeit
1716
from torchrl.collectors import SyncDataCollector
@@ -27,7 +26,7 @@
2726

2827

2928
@hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
30-
def main(cfg: "DictConfig"): # noqa: F821
29+
def main(cfg: DictConfig): # noqa: F821
3130

3231
device = cfg.device
3332
if device in ("", None):

sota-implementations/dreamer/dreamer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
import torch.cuda
1313
import tqdm
14+
1415
from dreamer_utils import (
1516
_default_device,
1617
dump_video,
@@ -27,7 +28,6 @@
2728
from torchrl._utils import logger as torchrl_logger, timeit
2829
from torchrl.envs.utils import ExplorationType, set_exploration_type
2930
from torchrl.modules import RSSMRollout
30-
3131
from torchrl.objectives.dreamer import (
3232
DreamerActorLoss,
3333
DreamerModelLoss,
@@ -37,7 +37,7 @@
3737

3838

3939
@hydra.main(version_base="1.1", config_path="", config_name="config")
40-
def main(cfg: "DictConfig"): # noqa: F821
40+
def main(cfg: DictConfig): # noqa: F821
4141
# cfg = correct_for_frame_skip(cfg)
4242

4343
device = _default_device(cfg.networks.device)

0 commit comments

Comments
 (0)