
Commit a74b323

Run pyfmt on .github and tools (#6165)
Run `lintrunner -a --all-files --take PYFMT`. Update pyproject.toml with a few options for usort that make it more similar to pytorch. I can break this down into smaller PRs if this is too large.
1 parent a797a36 commit a74b323

76 files changed: +943, -643 lines

Note: large commits have some diffs hidden by default, so only a subset of the 76 changed files is shown below.


.github/scripts/benchmarks/gather_metadata.py

Lines changed: 1 addition & 1 deletion
@@ -5,8 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import os
 import json
+import os
 import time
 from typing import Any
 
.github/scripts/get_tutorials_stats.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 
 import boto3  # type: ignore[import]
 
+
 METADATA_PATH = "ossci_tutorials_stats/metadata.csv"
 FILENAMES_PATH = "ossci_tutorials_stats/filenames.csv"
 
.github/scripts/update_commit_hashes.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 
 import requests
 
+
 UPDATEBOT_TOKEN = os.environ["UPDATEBOT_TOKEN"]
 PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]

.github/scripts/upload_benchmark_results.py

Lines changed: 1 addition & 1 deletion
@@ -14,13 +14,13 @@
 from argparse import Action, ArgumentParser, Namespace
 from decimal import Decimal
 from json.decoder import JSONDecodeError
-
 from logging import info
 from typing import Any, Callable, Dict, List, Optional
 from warnings import warn
 
 import boto3
 
+
 logging.basicConfig(level=logging.INFO)
 
 
.github/scripts/validate_scale_config.py

Lines changed: 1 addition & 3 deletions
@@ -9,16 +9,14 @@
 import copy
 import json
 import os
-
 import urllib.request
 from pathlib import Path
-
 from typing import Any, cast, Dict, List, NamedTuple
 
 import jsonschema
-
 import yaml
 
+
 MAX_AVAILABLE_MINIMUM = 50
 
 # Paths relative to their respective repositories

.lintrunner.toml

Lines changed: 0 additions & 2 deletions
@@ -266,10 +266,8 @@ include_patterns = [
     '**/*.py',
 ]
 exclude_patterns = [
-    '.github/scripts/**',
     'aws/lambda/**',
     's3_management/**',
-    'tools/**',
 ]
 command = [
     'python3',
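With '.github/scripts/**' and 'tools/**' dropped from exclude_patterns, the PYFMT linter now covers both trees, which is why every reformatted file in this commit lives under .github/ or tools/; per the commit message, the whole pass is reproduced by `lintrunner -a --all-files --take PYFMT`.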

pyproject.toml

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,12 @@ requires = [
 # Use legacy backend to import local packages in setup.py
 build-backend = "setuptools.build_meta:__legacy__"
 
+[tool.isort]
+lines_after_imports = 2
+multi_line_output = 3
+indent = 4
+include_trailing_comma = true
+combine_as_imports = true
 
 [tool.black]
 # Uncomment if pyproject.toml worked fine to ensure consistency with flake8
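For a sense of what these options do (the module below is illustrative, not from the commit): usort keeps imports sorted within their stdlib and third-party blocks, and lines_after_imports = 2 puts two blank lines after the final import — the change repeated across the Python diffs here. multi_line_output = 3 with include_trailing_comma = true selects the black-compatible vertical hanging indent for imports that must wrap.

# Before the PYFMT pass (hypothetical module): stdlib imports unsorted,
# single blank line after the import block
import os
import json
import time

CONFIG = "example"

# After the pass: imports sorted within the block, two blank lines after
# it (lines_after_imports = 2)
import json
import os
import time


CONFIG = "example"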

tools/analytics/cubinsizes.py

Lines changed: 44 additions & 25 deletions
@@ -12,27 +12,30 @@
 try:
     from elftools.elf.elffile import ELFFile
 except ModuleNotFoundError:
-    print(f'elftools module not found, trying to install it from pip')
+    print(f"elftools module not found, trying to install it from pip")
     from pip._internal import main as pip_main
+
     try:
         pip_main(["install", "pyelftools", "--user"])
     except SystemExit:
-        print(f'PIP installation failed, please install it manually by invoking "{sys.executable} -mpip install pyelftools --user"')
+        print(
+            f'PIP installation failed, please install it manually by invoking "{sys.executable} -mpip install pyelftools --user"'
+        )
         sys.exit(-1)
     from elftools.elf.elffile import ELFFile
 
 
 # From https://stackoverflow.com/questions/1094841/reusable-library-to-get-human-readable-version-of-file-size
-def sizeof_fmt(num, suffix='B'):
-    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
+def sizeof_fmt(num, suffix="B"):
+    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
         if abs(num) < 1024.0:
             return "%3.1f%s%s" % (num, unit, suffix)
         num /= 1024.0
-    return "%.1f%s%s" % (num, 'Yi', suffix)
+    return "%.1f%s%s" % (num, "Yi", suffix)
 
 
-def compute_cubin_sizes(file_name, section_name='.nv_fatbin', debug=False):
-    with open(file_name, 'rb') as f:
+def compute_cubin_sizes(file_name, section_name=".nv_fatbin", debug=False):
+    with open(file_name, "rb") as f:
         elf_file = ELFFile(f)
         nv_fatbin = elf_file.get_section_by_name(section_name)
         if nv_fatbin is None:
@@ -41,20 +44,32 @@ def compute_cubin_sizes(file_name, section_name='.nv_fatbin', debug=False):
         idx, offs = 0, 0
         elf_sizes = {}
         while offs < len(data):
-            (magic, version, header_size, fatbin_size) = struct.unpack('IHHL', data[offs: offs + 16])
-            if magic != 0xba55ed50 or version != 1:
-                raise RuntimeError(f"Unexpected fatbin magic {hex(magic)} or version {version}")
+            (magic, version, header_size, fatbin_size) = struct.unpack(
+                "IHHL", data[offs : offs + 16]
+            )
+            if magic != 0xBA55ED50 or version != 1:
+                raise RuntimeError(
+                    f"Unexpected fatbin magic {hex(magic)} or version {version}"
+                )
             if debug:
-                print(f"Found fatbin at {offs} header_size={header_size} fatbin_size={fatbin_size}")
+                print(
+                    f"Found fatbin at {offs} header_size={header_size} fatbin_size={fatbin_size}"
+                )
             offs += header_size
             fatbin_end = offs + fatbin_size
             while offs < fatbin_end:
-                (kind, version, hdr_size, elf_size, empty, code_ver, sm_ver) = struct.unpack('HHILLIH', data[offs: offs + 30])
+                (kind, version, hdr_size, elf_size, empty, code_ver, sm_ver) = (
+                    struct.unpack("HHILLIH", data[offs : offs + 30])
+                )
                 if version != 0x0101 or kind not in [1, 2]:
-                    raise RuntimeError(f"Unexpected cubin version {hex(version)} or kind {kind}")
+                    raise RuntimeError(
+                        f"Unexpected cubin version {hex(version)} or kind {kind}"
+                    )
                 sm_ver = f'{"ptx" if kind == 1 else "sm"}_{sm_ver}'
                 if debug:
-                    print(f"  {idx}: elf_size={elf_size} code_ver={hex(code_ver)} sm={sm_ver}")
+                    print(
+                        f"  {idx}: elf_size={elf_size} code_ver={hex(code_ver)} sm={sm_ver}"
+                    )
                 if sm_ver not in elf_sizes:
                     elf_sizes[sm_ver] = 0
                 elf_sizes[sm_ver] += elf_size
@@ -71,7 +86,7 @@ def __init__(self, ar_name: str) -> None:
     def __enter__(self) -> str:
         self._pwd = os.getcwd()
         rc = self._tmpdir.__enter__()
-        subprocess.check_call(['ar', 'x', self.ar_name])
+        subprocess.check_call(["ar", "x", self.ar_name])
         return rc
 
     def __exit__(self, ex, value, tb) -> None:
@@ -86,40 +101,44 @@ def dict_add(rc: Dict[str, int], b: Dict[str, int]) -> Dict[str, int]:
 
 
 def main():
-    if sys.platform != 'linux':
-        print('This script only works with Linux ELF files')
+    if sys.platform != "linux":
+        print("This script only works with Linux ELF files")
         return
     if len(sys.argv) < 2:
-        print(f"{sys.argv[0]} invoked without any arguments trying to infer location of libtorch_cuda")
+        print(
+            f"{sys.argv[0]} invoked without any arguments trying to infer location of libtorch_cuda"
+        )
         import torch
-        fname = os.path.join(os.path.dirname(torch.__file__), 'lib', 'libtorch_cuda.so')
+
+        fname = os.path.join(os.path.dirname(torch.__file__), "lib", "libtorch_cuda.so")
     else:
         fname = sys.argv[1]
 
     if not os.path.exists(fname):
         print(f"Can't find {fname}")
         sys.exit(-1)
 
-    section_names = ['.nv_fatbin', '__nv_relfatbin']
+    section_names = [".nv_fatbin", "__nv_relfatbin"]
     results = {name: {} for name in section_names}
     print(f"Analyzing {fname}")
-    if os.path.splitext(fname)[1] == '.a':
+    if os.path.splitext(fname)[1] == ".a":
         with ArFileCtx(fname):
             for fname in os.listdir("."):
-                if not fname.endswith(".o"): continue
+                if not fname.endswith(".o"):
+                    continue
                 for section_name in section_names:
                     elf_sizes = compute_cubin_sizes(fname, section_name)
                     dict_add(results[section_name], elf_sizes)
     else:
-        for section_name in ['.nv_fatbin', '__nv_relfatbin']:
+        for section_name in [".nv_fatbin", "__nv_relfatbin"]:
             dict_add(results[section_name], compute_cubin_sizes(fname, section_name))
 
     for section_name in section_names:
         elf_sizes = results[section_name]
         print(f"{section_name} size {sizeof_fmt(sum(elf_sizes.values()))}")
-        for (sm_ver, total_size) in elf_sizes.items():
+        for sm_ver, total_size in elf_sizes.items():
             print(f"  {sm_ver}: {sizeof_fmt(total_size)}")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
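The reformat above is behavior-preserving: compute_cubin_sizes still reads the .nv_fatbin section as a stream of 16-byte fatbin headers via struct.unpack("IHHL", ...). A minimal sketch of that header round-trip, assuming a 64-bit Linux Python where the native "L" is 8 bytes (all values below are made up, except the 0xBA55ED50 magic, which comes from the script):

import struct

# One fake fatbin header: magic (u32), version (u16), header_size (u16),
# fatbin_size (u64), packed with native alignment
header = struct.pack("IHHL", 0xBA55ED50, 1, 16, 4096)
assert len(header) == 16  # holds on x86-64 Linux, as the script assumes

magic, version, header_size, fatbin_size = struct.unpack("IHHL", header)
assert magic == 0xBA55ED50 and version == 1
print(header_size, fatbin_size)  # -> 16 4096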
(file name hidden)

Lines changed: 32 additions & 43 deletions
@@ -1,16 +1,18 @@
-from collections import defaultdict
-from datetime import datetime, timedelta, timezone
 import gzip
 import os
 import re
 import urllib
+from collections import defaultdict
+from datetime import datetime, timedelta, timezone
 
-from tqdm import tqdm
 import boto3
+from tqdm import tqdm
+
+
+S3 = boto3.resource("s3")
+CLIENT = boto3.client("s3")
+BUCKET = S3.Bucket("pytorch")
 
-S3 = boto3.resource('s3')
-CLIENT = boto3.client('s3')
-BUCKET = S3.Bucket('pytorch')
 
 class CacheEntry:
     _size = None
@@ -38,66 +40,56 @@ def target_arch(self) -> str:
 
     @property
     def package_name(self) -> str:
-        filename_contents = os.path.basename(self.download_uri).split('-')
+        filename_contents = os.path.basename(self.download_uri).split("-")
         return filename_contents[0]
 
     @property
     def package_version(self) -> str:
         if "dev" in self.download_uri:
-            results = re.search(
-                r"[0-9]+\.[0-9]+\.[0-9]+\.dev[0-9]+",
-                self.download_uri
-            )
+            results = re.search(r"[0-9]+\.[0-9]+\.[0-9]+\.dev[0-9]+", self.download_uri)
         else:
-            results = re.search(
-                r"[0-9]+\.[0-9]+\.[0-9]+", self.download_uri
-            )
+            results = re.search(r"[0-9]+\.[0-9]+\.[0-9]+", self.download_uri)
         if not results:
             raise Exception("Wtf there's no version o.O")
         return results[0]
 
     @property
     def size(self) -> int:
         if self._size is None:
-            for key in BUCKET.objects.filter(
-                Prefix=self.download_uri.lstrip("/")
-            ):
+            for key in BUCKET.objects.filter(Prefix=self.download_uri.lstrip("/")):
                 self._size = key.size
             if self._size is None:
-                raise Exception(
-                    f"No object found for prefix {self.download_uri}"
-                )
+                raise Exception(f"No object found for prefix {self.download_uri}")
         return self._size
 
     @property
     def downloads(self):
         return self.bytes_sent // self.size
 
+
 def parse_logs(log_directory: str) -> dict:
     bytes_cache = {}
-    for (dirpath, _, filenames) in os.walk(log_directory):
+    for dirpath, _, filenames in os.walk(log_directory):
         for filename in tqdm(filenames):
-            with gzip.open(os.path.join(dirpath, filename), 'r') as gf:
+            with gzip.open(os.path.join(dirpath, filename), "r") as gf:
                 string = gf.read().decode("utf-8")
             entries = []
             entries += string.splitlines()[2:]
             for entry in entries:
-                columns = entry.split('\t')
+                columns = entry.split("\t")
                 bytes_sent = int(columns[3])
-                download_uri = urllib.parse.unquote(
-                    urllib.parse.unquote(columns[7])
-                )
+                download_uri = urllib.parse.unquote(urllib.parse.unquote(columns[7]))
                 status = columns[8]
-                if not all([
-                    status.startswith("2"),
-                    download_uri.endswith((".whl", ".zip"))
-                ]):
+                if not all(
+                    [status.startswith("2"), download_uri.endswith((".whl", ".zip"))]
+                ):
                     continue
                 if not bytes_cache.get(download_uri):
                     bytes_cache[download_uri] = CacheEntry(download_uri)
                 bytes_cache[download_uri].bytes_sent += bytes_sent
     return bytes_cache
 
+
 def output_results(bytes_cache: dict) -> None:
     os_results = defaultdict(int)
     arch_results = defaultdict(int)
@@ -106,25 +98,19 @@ def output_results(bytes_cache: dict) -> None:
         try:
             os_results[val.os_type] += val.downloads
             arch_results[val.target_arch] += val.downloads
-            package_results[val.package_name][val.package_version] += (
-                val.downloads
-            )
+            package_results[val.package_name][val.package_version] += val.downloads
         except Exception:
             pass
     print("=-=-= Results =-=-=")
     print("=-=-= OS =-=-=")
     total_os_num = sum(os_results.values())
     for os_type, num in os_results.items():
-        print(
-            f"\t* {os_type}: {num} ({(num/total_os_num)*100:.2f}%)"
-        )
+        print(f"\t* {os_type}: {num} ({(num/total_os_num)*100:.2f}%)")
 
     print("=-=-= ARCH =-=-=")
     total_arch_num = sum(arch_results.values())
     for arch_type, num in arch_results.items():
-        print(
-            f"\t* {arch_type}: {num} ({(num/total_arch_num) * 100:.2f}%)"
-        )
+        print(f"\t* {arch_type}: {num} ({(num/total_arch_num) * 100:.2f}%)")
 
     print("=-=-= By Package =-=-=")
     for package_name, upper_val in package_results.items():
@@ -135,11 +121,14 @@ def output_results(bytes_cache: dict) -> None:
                 f"\t* {package_version}: {num} ({(num/total_package_num) * 100:.2f}%)"
             )
 
+
 def download_logs(log_directory: str, since: float):
     dt_now = datetime.now(timezone.utc)
     dt_end = datetime(dt_now.year, dt_now.month, dt_now.day, tzinfo=timezone.utc)
-    dt_start = dt_end - timedelta(days=1, hours=1)  # Add 1 hour padding to account for potentially missed logs due to timing
-    for key in tqdm(BUCKET.objects.filter(Prefix='cflogs')):
+    dt_start = dt_end - timedelta(
+        days=1, hours=1
+    )  # Add 1 hour padding to account for potentially missed logs due to timing
+    for key in tqdm(BUCKET.objects.filter(Prefix="cflogs")):
         remote_fname = key.key
         local_fname = os.path.join(log_directory, remote_fname)
         # Only download things from yesterday
@@ -156,8 +145,8 @@ def download_logs(log_directory: str, since: float):
 
 if __name__ == "__main__":
     print("Downloading logs")
-    download_logs('cache', 1)
+    download_logs("cache", 1)
     print("Parsing logs")
-    cache = parse_logs('cache/cflogs/')
+    cache = parse_logs("cache/cflogs/")
     print("Calculating results")
     output_results(cache)
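A side note on the stats script above (its logic is untouched by this commit): it never sees a download count directly. CacheEntry.downloads estimates one by floor-dividing the bytes served for a URI, accumulated from the gzipped logs, by the object's size in S3. A toy illustration with made-up numbers:

# Hypothetical totals for a single wheel
bytes_sent = 750_000_000  # bytes served, accumulated by parse_logs
size = 50_000_000  # object size looked up from the S3 bucket

downloads = bytes_sent // size  # the same floor division as CacheEntry.downloads
print(downloads)  # -> 15 (partial transfers round the estimate down)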
