Skip to content

Commit b59be23

Browse files
authored
style: Bumps ruff to 0.1.9 and mypy to 1.8.0 (#226)
* chore: Bumps ruff and mypy * chore: Updates precommit * chore: Updates gitignore * style: Fixes lint * style: Fixes typing * ci: Updates CI jobs
1 parent a4e5f2e commit b59be23

22 files changed

+250
-173
lines changed

.github/collect_env.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,7 @@ def get_os(run_lambda):
217217
def get_env_info():
218218
run_lambda = run
219219

220-
if TORCHCAM_AVAILABLE:
221-
torchcam_str = torchcam.__version__
222-
else:
223-
torchcam_str = "N/A"
220+
torchcam_str = torchcam.__version__ if TORCHCAM_AVAILABLE else "N/A"
224221

225222
if TORCH_AVAILABLE:
226223
torch_str = torch.__version__
@@ -258,14 +255,14 @@ def get_env_info():
258255

259256
def pretty_str(envinfo):
260257
def replace_nones(dct, replacement="Could not collect"):
261-
for key in dct.keys():
258+
for key in dct:
262259
if dct[key] is not None:
263260
continue
264261
dct[key] = replacement
265262
return dct
266263

267264
def replace_bools(dct, true="Yes", false="No"):
268-
for key in dct.keys():
265+
for key in dct:
269266
if dct[key] is True:
270267
dct[key] = true
271268
elif dct[key] is False:

.github/verify_labels.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def parse_args():
7171
import argparse
7272

7373
parser = argparse.ArgumentParser(
74-
description="PR label checker", formatter_class=argparse.ArgumentDefaultsHelpFormatter
74+
description="PR label checker",
75+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
7576
)
7677

7778
parser.add_argument("pr", type=int, help="PR number")

.github/workflows/publish.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jobs:
5656
steps:
5757
- uses: actions/checkout@v2
5858
- name: Miniconda setup
59-
uses: conda-incubator/setup-miniconda@v2
59+
uses: conda-incubator/setup-miniconda@v3
6060
with:
6161
auto-update-conda: true
6262
python-version: 3.9

.github/workflows/style.yml

+9-31
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@ jobs:
1515
python: [3.9]
1616
steps:
1717
- uses: actions/checkout@v2
18-
- name: Set up Python
19-
uses: actions/setup-python@v4
18+
- uses: actions/setup-python@v4
2019
with:
2120
python-version: ${{ matrix.python }}
2221
architecture: x64
2322
- name: Run ruff
2423
run: |
25-
pip install ruff==0.1.0
24+
pip install ruff==0.1.9
2625
ruff --version
2726
ruff check --diff .
2827
@@ -34,8 +33,7 @@ jobs:
3433
python: [3.9]
3534
steps:
3635
- uses: actions/checkout@v2
37-
- name: Set up Python
38-
uses: actions/setup-python@v4
36+
- uses: actions/setup-python@v4
3937
with:
4038
python-version: ${{ matrix.python }}
4139
architecture: x64
@@ -53,40 +51,20 @@ jobs:
5351
mypy --version
5452
mypy
5553
56-
black:
54+
ruff-format:
5755
runs-on: ${{ matrix.os }}
5856
strategy:
5957
matrix:
6058
os: [ubuntu-latest]
6159
python: [3.9]
6260
steps:
6361
- uses: actions/checkout@v2
64-
- name: Set up Python
65-
uses: actions/setup-python@v4
62+
- uses: actions/setup-python@v4
6663
with:
6764
python-version: ${{ matrix.python }}
6865
architecture: x64
69-
- name: Run black
70-
run: |
71-
pip install "black==23.3.0"
72-
black --version
73-
black --check --diff .
74-
75-
bandit:
76-
runs-on: ${{ matrix.os }}
77-
strategy:
78-
matrix:
79-
os: [ubuntu-latest]
80-
python: [3.9]
81-
steps:
82-
- uses: actions/checkout@v2
83-
- name: Set up Python
84-
uses: actions/setup-python@v4
85-
with:
86-
python-version: ${{ matrix.python }}
87-
architecture: x64
88-
- name: Run bandit
66+
- name: Run ruff
8967
run: |
90-
pip install bandit[toml]
91-
bandit --version
92-
bandit -r . -c pyproject.toml
68+
pip install ruff==0.1.9
69+
ruff --version
70+
ruff format --check --diff .

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,4 @@ torchcam/version.py
133133

134134
# Conda distribution
135135
conda-dist/
136+
.vscode/

.pre-commit-config.yaml

+2-5
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,10 @@ repos:
1717
args: ['--branch', 'main']
1818
- id: debug-statements
1919
language_version: python3
20-
- repo: https://github.com/psf/black
21-
rev: 23.3.0
22-
hooks:
23-
- id: black
2420
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: 'v0.0.290'
21+
rev: 'v0.1.9'
2622
hooks:
2723
- id: ruff
2824
args:
2925
- --fix
26+
- id: ruff-format

Makefile

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# this target runs checks on all files
22
quality:
3+
ruff format --check .
34
ruff check .
45
mypy
5-
black --check .
6-
bandit -r . -c pyproject.toml
76

87
# this target runs checks on all files and potentially modifies some of them
98
style:
10-
black .
9+
ruff format .
1110
ruff --fix .
1211

1312
# Run tests for the library

demo/app.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,17 @@
1717
from torchcam.methods._utils import locate_candidate_layer
1818
from torchcam.utils import overlay_mask
1919

20-
CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
20+
CAM_METHODS = [
21+
"CAM",
22+
"GradCAM",
23+
"GradCAMpp",
24+
"SmoothGradCAMpp",
25+
"ScoreCAM",
26+
"SSCAM",
27+
"ISCAM",
28+
"XGradCAM",
29+
"LayerCAM",
30+
]
2131
TV_MODELS = [
2232
"resnet18",
2333
"resnet50",
@@ -87,7 +97,8 @@ def main():
8797
)
8898
if cam_method is not None:
8999
cam_extractor = methods.__dict__[cam_method](
90-
model, target_layer=[s.strip() for s in target_layer.split("+")] if len(target_layer) > 0 else None
100+
model,
101+
target_layer=[s.strip() for s in target_layer.split("+")] if len(target_layer) > 0 else None,
91102
)
92103

93104
class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
@@ -103,7 +114,11 @@ def main():
103114
else:
104115
with st.spinner("Analyzing..."):
105116
# Preprocess image
106-
img_tensor = normalize(to_tensor(resize(img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
117+
img_tensor = normalize(
118+
to_tensor(resize(img, (224, 224))),
119+
[0.485, 0.456, 0.406],
120+
[0.229, 0.224, 0.225],
121+
)
107122

108123
if torch.cuda.is_available():
109124
img_tensor = img_tensor.cuda()

docs/source/conf.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from datetime import datetime
2020
from pathlib import Path
2121

22-
sys.path.insert(0, Path().resolve().parent.parent)
22+
sys.path.insert(0, Path().cwd().parent.parent)
2323
import torchcam
2424

2525
# -- Project information -----------------------------------------------------
@@ -121,9 +121,7 @@ def add_ga_javascript(app, pagename, templatename, context, doctree):
121121
gtag('js', new Date());
122122
gtag('config', '{0}');
123123
</script>
124-
""".format(
125-
app.config.googleanalytics_id
126-
)
124+
""".format(app.config.googleanalytics_id)
127125
context["metatags"] = metatags
128126

129127

pyproject.toml

+33-23
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,9 @@ test = [
5050
"pytest-pretty>=1.0.0,<2.0.0",
5151
]
5252
quality = [
53-
"ruff==0.1.0",
54-
"mypy==1.5.1",
55-
"black==23.3.0",
56-
"bandit[toml]>=1.7.0,<1.8.0",
57-
"pre-commit>=2.17.0,<3.0.0",
53+
"ruff==0.1.9",
54+
"mypy==1.8.0",
55+
"pre-commit>=3.0.0,<4.0.0",
5856
]
5957
docs = [
6058
"sphinx>=3.0.0,!=3.5.0",
@@ -80,11 +78,9 @@ dev = [
8078
"pytest-xdist>=3.0.0,<4.0.0",
8179
"pytest-pretty>=1.0.0,<2.0.0",
8280
# style
83-
"ruff==0.1.0",
84-
"mypy==1.5.1",
85-
"black==23.3.0",
86-
"bandit[toml]>=1.7.0,<1.8.0",
87-
"pre-commit>=2.17.0,<3.0.0",
81+
"ruff==0.1.9",
82+
"mypy==1.8.0",
83+
"pre-commit>=3.0.0,<4.0.0",
8884
# docs
8985
"sphinx>=3.0.0,!=3.5.0",
9086
"furo>=2022.3.4",
@@ -133,6 +129,16 @@ select = [
133129
"T20", # flake8-print
134130
"PT", # flake8-pytest-style
135131
"LOG", # flake8-logging
132+
"SIM", # flake8-simplify
133+
"YTT", # flake8-2020
134+
"ANN", # flake8-annotations
135+
"ASYNC", # flake8-async
136+
"BLE", # flake8-blind-except
137+
"A", # flake8-builtins
138+
"ICN", # flake8-import-conventions
139+
"PIE", # flake8-pie
140+
"ARG", # flake8-unused-arguments
141+
"FURB", # refurb
136142
]
137143
ignore = [
138144
"E501", # line too long, handled by black
@@ -142,20 +148,31 @@ ignore = [
142148
"F403", # star imports
143149
"E731", # lambda assignment
144150
"C416", # list comprehension to list()
151+
"ANN101", # missing type annotations on self
152+
"ANN102", # missing type annotations on cls
153+
"ANN002", # missing type annotations on *args
154+
"ANN003", # missing type annotations on **kwargs
155+
"COM812", # trailing comma missing
145156
"N812", # lowercase imported as non-lowercase
157+
"ISC001", # implicit string concatenation (handled by format)
158+
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
146159
]
147160
exclude = [".git"]
148161
line-length = 120
149162
target-version = "py39"
150163
preview = true
151164

165+
[tool.ruff.format]
166+
quote-style = "double"
167+
indent-style = "space"
168+
152169
[tool.ruff.per-file-ignores]
153170
"**/__init__.py" = ["I001", "F401", "CPY001"]
154-
"scripts/**.py" = ["D", "T201", "N812"]
155-
".github/**.py" = ["D", "T201", "S602"]
156-
"docs/**.py" = ["E402", "D103"]
157-
"tests/**.py" = ["D103", "CPY001", "S101", "PT011",]
158-
"demo/**.py" = ["D103"]
171+
"scripts/**.py" = ["D", "T201", "N812", "S101", "ANN"]
172+
".github/**.py" = ["D", "T201", "S602", "S101", "ANN"]
173+
"docs/**.py" = ["E402", "D103", "ANN", "A001", "ARG001"]
174+
"tests/**.py" = ["D103", "CPY001", "S101", "PT011", "ANN"]
175+
"demo/**.py" = ["D103", "ANN"]
159176
"setup.py" = ["T201"]
160177

161178
[tool.ruff.flake8-quotes]
@@ -177,18 +194,11 @@ no_implicit_optional = true
177194
check_untyped_defs = true
178195
implicit_reexport = false
179196
disallow_untyped_defs = true
197+
explicit_package_bases = true
180198

181199
[[tool.mypy.overrides]]
182200
module = [
183201
"PIL",
184202
"matplotlib"
185203
]
186204
ignore_missing_imports = true
187-
188-
[tool.black]
189-
line-length = 120
190-
target-version = ['py39']
191-
192-
[tool.bandit]
193-
exclude_dirs = [".github/collect_env.py"]
194-
skips = ["B101"]

scripts/cam_example.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,15 @@ def main(args):
3535
p.requires_grad_(False)
3636

3737
# Image
38-
if args.img.startswith("http"):
39-
img_path = BytesIO(requests.get(args.img, timeout=5).content)
40-
else:
41-
img_path = args.img
38+
img_path = BytesIO(requests.get(args.img, timeout=5).content) if args.img.startswith("http") else args.img
4239
pil_img = Image.open(img_path, mode="r").convert("RGB")
4340

4441
# Preprocess image
45-
img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to(
46-
device=device
47-
)
42+
img_tensor = normalize(
43+
to_tensor(resize(pil_img, (224, 224))),
44+
[0.485, 0.456, 0.406],
45+
[0.229, 0.224, 0.225],
46+
).to(device=device)
4847
img_tensor.requires_grad_(True)
4948

5049
if isinstance(args.method, str):
@@ -119,7 +118,8 @@ def main(args):
119118

120119
if __name__ == "__main__":
121120
parser = argparse.ArgumentParser(
122-
description="Saliency Map comparison", formatter_class=argparse.ArgumentDefaultsHelpFormatter
121+
description="Saliency Map comparison",
122+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
123123
)
124124
parser.add_argument("--arch", type=str, default="resnet18", help="Name of the architecture")
125125
parser.add_argument(
@@ -129,13 +129,23 @@ def main(args):
129129
help="The image to extract CAM from",
130130
)
131131
parser.add_argument("--class-idx", type=int, default=232, help="Index of the class to inspect")
132-
parser.add_argument("--device", type=str, default=None, help="Default device to perform computation on")
132+
parser.add_argument(
133+
"--device",
134+
type=str,
135+
default=None,
136+
help="Default device to perform computation on",
137+
)
133138
parser.add_argument("--savefig", type=str, default=None, help="Path to save figure")
134139
parser.add_argument("--method", type=str, default=None, help="CAM method to use")
135140
parser.add_argument("--target", type=str, default=None, help="the target layer")
136141
parser.add_argument("--alpha", type=float, default=0.5, help="Transparency of the heatmap")
137142
parser.add_argument("--rows", type=int, default=1, help="Number of rows for the layout")
138-
parser.add_argument("--noblock", dest="noblock", help="Disables blocking visualization", action="store_true")
143+
parser.add_argument(
144+
"--noblock",
145+
dest="noblock",
146+
help="Disables blocking visualization",
147+
action="store_true",
148+
)
139149
args = parser.parse_args()
140150

141151
main(args)

0 commit comments

Comments
 (0)