Commit 8a67daf

XuehaiPan authored and pytorchmergebot committed

[BE][Easy] enable postponed annotations in tools (pytorch#129375)

Pull Request resolved: pytorch#129375
Approved by: https://github.com/malfet

1 parent 58f346c commit 8a67daf
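
For context: `from __future__ import annotations` (PEP 563) makes every annotation in a module a lazily evaluated string. That is what lets the files below use PEP 585 builtin generics such as `list[str]` and `dict[str, Any]`, and move annotation-only imports under `TYPE_CHECKING`, while still running on Python versions that predate runtime support for those forms. A minimal sketch of the pattern, with an illustrative function that is not taken from the commit:

from __future__ import annotations  # all annotations below are stored as strings (PEP 563)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from collections.abc import Sequence


def split_words(text: str, seps: Sequence[str] = (" ",)) -> list[str]:
    # `list[str]` and `Sequence[str]` are never evaluated at runtime, so this
    # module imports cleanly even on Python 3.7/3.8, where builtin generics
    # cannot be subscripted.
    words = [text]
    for sep in seps:
        words = [part for chunk in words for part in chunk.split(sep)]
    return [w for w in words if w]


print(split_words("a b,c", seps=(" ", ",")))  # -> ['a', 'b', 'c']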

File tree

123 files changed: +1274 -1053 lines changed


tools/alerts/create_alerts.py (+20 -17)

@@ -1,16 +1,19 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import argparse
 import json
 import os
 import re
 from collections import defaultdict
 from difflib import SequenceMatcher
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any
 
 import requests
 from setuptools import distutils  # type: ignore[import]
 
+
 ALL_SKIPPED_THRESHOLD = 100
 SIMILARITY_THRESHOLD = 0.75
 FAILURE_CHAIN_THRESHOLD = 2
@@ -65,14 +68,14 @@
 
 class JobStatus:
     job_name: str = ""
-    jobs: List[Any] = []
+    jobs: list[Any] = []
     current_status: Any = None
-    job_statuses: List[Any] = []
-    filtered_statuses: List[Any] = []
-    failure_chain: List[Any] = []
-    flaky_jobs: List[Any] = []
+    job_statuses: list[Any] = []
+    filtered_statuses: list[Any] = []
+    failure_chain: list[Any] = []
+    flaky_jobs: list[Any] = []
 
-    def __init__(self, job_name: str, job_statuses: List[Any]):
+    def __init__(self, job_name: str, job_statuses: list[Any]) -> None:
         self.job_name = job_name
         self.job_statuses = job_statuses
 
@@ -93,7 +96,7 @@ def get_current_status(self) -> Any:
                 return status
         return None
 
-    def get_unique_failures(self, jobs: List[Any]) -> Dict[str, List[Any]]:
+    def get_unique_failures(self, jobs: list[Any]) -> dict[str, list[Any]]:
         """
         Returns list of jobs grouped by failureCaptures from the input list
         """
@@ -120,7 +123,7 @@ def get_unique_failures(self, jobs: List[Any]) -> Dict[str, List[Any]]:
         return failures
 
     # A flaky job is if it's the only job that has that failureCapture and is not the most recent job
-    def get_flaky_jobs(self) -> List[Any]:
+    def get_flaky_jobs(self) -> list[Any]:
         unique_failures = self.get_unique_failures(self.filtered_statuses)
         flaky_jobs = []
         for failure in unique_failures:
@@ -134,7 +137,7 @@ def get_flaky_jobs(self) -> List[Any]:
 
     # The most recent failure chain is an array of jobs that have the same-ish failures.
     # A success in the middle of the chain will terminate the chain.
-    def get_most_recent_failure_chain(self) -> List[Any]:
+    def get_most_recent_failure_chain(self) -> list[Any]:
         failures = []
         found_most_recent_failure = False
 
@@ -178,7 +181,7 @@ def fetch_hud_data(repo: str, branch: str) -> Any:
 
 
 # Creates a Dict of Job Name -> [JobData]. Essentially a Column in HUD
-def map_job_data(jobNames: Any, shaGrid: Any) -> Dict[str, Any]:
+def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]:
     jobData = defaultdict(list)
     for sha in shaGrid:
         for ind, job in enumerate(sha["jobs"]):
@@ -196,13 +199,13 @@ def is_job_skipped(job: Any) -> bool:
     return conclusion in (NEUTRAL, SKIPPED) or conclusion is None
 
 
-def get_failed_jobs(job_data: List[Any]) -> List[Any]:
+def get_failed_jobs(job_data: list[Any]) -> list[Any]:
     return [job for job in job_data if job["conclusion"] == "failure"]
 
 
 def classify_jobs(
-    all_job_names: List[str], sha_grid: Any, filtered_jobs_names: Set[str]
-) -> Tuple[List[JobStatus], List[Any]]:
+    all_job_names: list[str], sha_grid: Any, filtered_jobs_names: set[str]
+) -> tuple[list[JobStatus], list[Any]]:
     """
     Creates Job Statuses which has the logic for if need to alert or if there's flaky jobs.
     Classifies jobs into jobs to alert on and flaky jobs.
@@ -212,7 +215,7 @@ def classify_jobs(
     :return:
     """
     job_data = map_job_data(all_job_names, sha_grid)
-    job_statuses: List[JobStatus] = []
+    job_statuses: list[JobStatus] = []
     for job in job_data:
         job_statuses.append(JobStatus(job, job_data[job]))
 
@@ -230,7 +233,7 @@
 
 
 # filter job names that don't match the regex
-def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:
+def filter_job_names(job_names: list[str], job_name_regex: str) -> list[str]:
     if job_name_regex:
         return [
             job_name for job_name in job_names if re.match(job_name_regex, job_name)
@@ -240,7 +243,7 @@ def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:
 
 def get_recurrently_failing_jobs_alerts(
     repo: str, branch: str, job_name_regex: str
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     job_names, sha_grid = fetch_hud_data(repo=repo, branch=branch)
 
     filtered_job_names = set(filter_job_names(job_names, job_name_regex))
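
The `List`/`Dict`/`Set`/`Tuple` to `list`/`dict`/`set`/`tuple` swap above is only safe because the new future import defers evaluation: without it, a class-level annotation like `jobs: list[Any] = []` would raise `TypeError: 'type' object is not subscriptable` at import time on Python 3.8. A small sketch (the class is an illustrative stand-in for `JobStatus`, not the commit's code):

from __future__ import annotations

from typing import Any


class JobStatusDemo:
    # Under PEP 563 these annotations are recorded as strings, so `list[Any]`
    # is never executed and the module imports fine on pre-3.9 interpreters.
    jobs: list[Any] = []
    flaky_jobs: list[Any] = []

    def __init__(self, job_name: str, job_statuses: list[Any]) -> None:
        self.job_name = job_name
        self.job_statuses = job_statuses


print(JobStatusDemo.__annotations__)
# -> {'jobs': 'list[Any]', 'flaky_jobs': 'list[Any]'}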

tools/autograd/gen_annotated_fn_args.py (+13 -10)

@@ -14,18 +14,17 @@
     torch/testing/_internal/generated
 """
 
+from __future__ import annotations
+
 import argparse
 import os
 import textwrap
 from collections import defaultdict
-
-from typing import Any, Dict, List, Sequence
+from typing import Any, Sequence, TYPE_CHECKING
 
 import torchgen.api.python as python
 from torchgen.context import with_native_function
-
 from torchgen.gen import parse_native_yaml
-from torchgen.model import Argument, BaseOperatorName, NativeFunction
 from torchgen.utils import FileManager
 
 from .gen_python_functions import (
@@ -39,6 +38,10 @@
 )
 
 
+if TYPE_CHECKING:
+    from torchgen.model import Argument, BaseOperatorName, NativeFunction
+
+
 def gen_annotated(
     native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
 ) -> None:
@@ -53,9 +56,9 @@ def gen_annotated(
         (is_py_fft_function, "torch._C._fft"),
         (is_py_variable_method, "torch.Tensor"),
     )
-    annotated_args: List[str] = []
+    annotated_args: list[str] = []
     for pred, namespace in mappings:
-        groups: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
+        groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
         for f in native_functions:
             if not should_generate_py_binding(f) or not pred(f):
                 continue
@@ -77,7 +80,7 @@ def gen_annotated(
 
 @with_native_function
 def gen_annotated_args(f: NativeFunction) -> str:
-    def _get_kwargs_func_exclusion_list() -> List[str]:
+    def _get_kwargs_func_exclusion_list() -> list[str]:
         # functions that currently don't work with kwargs in test_overrides.py
         return [
             "diagonal",
@@ -87,12 +90,12 @@ def _get_kwargs_func_exclusion_list() -> List[str]:
         ]
 
     def _add_out_arg(
-        out_args: List[Dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
+        out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
     ) -> None:
         for arg in args:
             if arg.default is not None:
                 continue
-            out_arg: Dict[str, Any] = {}
+            out_arg: dict[str, Any] = {}
             out_arg["is_kwarg_only"] = str(is_kwarg_only)
             out_arg["name"] = arg.name
             out_arg["simple_type"] = python.argument_type_str(
@@ -103,7 +106,7 @@ def _add_out_arg(
             out_arg["size"] = size_t
             out_args.append(out_arg)
 
-    out_args: List[Dict[str, Any]] = []
+    out_args: list[dict[str, Any]] = []
     _add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False)
     if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list():
         _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True)
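
Moving `Argument`, `BaseOperatorName`, and `NativeFunction` under `if TYPE_CHECKING:` is sound here because those names now appear only in annotations, which stay unevaluated at runtime. The same pattern in isolation, with a hypothetical `heavy_module` standing in for `torchgen.model`:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved by mypy/pyright only; skipped at runtime, which avoids import
    # cost and circular-import hazards.
    from heavy_module import ExpensiveType  # hypothetical module


def summarize(item: ExpensiveType) -> str:
    # At runtime the annotation is just the string "ExpensiveType", so
    # heavy_module is never imported when this function runs.
    return f"<{item!r}>"

One caveat worth knowing: runtime introspection such as `typing.get_type_hints(summarize)` cannot resolve names that exist only under `TYPE_CHECKING`, so this pattern suits codegen tools like these, which are type-checked but not introspected.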

tools/autograd/gen_autograd.py (+3 -2)

@@ -22,9 +22,10 @@
 # gen_python_functions.py: generates Python bindings to THPVariable
 #
 
+from __future__ import annotations
+
 import argparse
 import os
-from typing import List
 
 from torchgen.api import cpp
 from torchgen.api.autograd import (
@@ -69,7 +70,7 @@ def gen_autograd(
         ),
         key=lambda f: cpp.name(f.func),
     )
-    fns_with_diff_infos: List[
+    fns_with_diff_infos: list[
         NativeFunctionWithDifferentiabilityInfo
     ] = match_differentiability_info(fns, differentiability_infos)
 
tools/autograd/gen_autograd_functions.py

+23-19
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
# Functions.h/cpp: subclasses of autograd::Node
55
# python_functions.h/cpp: Python bindings for the above classes
66
#
7-
from typing import Dict, List, Sequence, Tuple
7+
8+
from __future__ import annotations
9+
10+
from typing import Sequence
811

912
from torchgen.api.autograd import (
1013
Derivative,
@@ -43,6 +46,7 @@
4346

4447
from .gen_inplace_or_view_type import VIEW_FUNCTIONS
4548

49+
4650
FUNCTION_DECLARATION = CodeTemplate(
4751
"""\
4852
#ifdef _WIN32
@@ -443,8 +447,8 @@
443447

444448

445449
def get_infos_with_derivatives_list(
446-
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]]
447-
) -> List[DifferentiabilityInfo]:
450+
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
451+
) -> list[DifferentiabilityInfo]:
448452
diff_info_list = [
449453
info
450454
for diffinfo_dict in differentiability_infos.values()
@@ -456,7 +460,7 @@ def get_infos_with_derivatives_list(
456460

457461
def gen_autograd_functions_lib(
458462
out: str,
459-
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
463+
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
460464
template_path: str,
461465
) -> None:
462466
"""Functions.h and Functions.cpp body
@@ -490,7 +494,7 @@ def gen_autograd_functions_lib(
490494

491495
def gen_autograd_functions_python(
492496
out: str,
493-
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
497+
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
494498
template_path: str,
495499
) -> None:
496500
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
@@ -536,17 +540,17 @@ def gen_autograd_functions_python(
536540

537541

538542
def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
539-
saved_variables: List[str] = []
540-
release_variables: List[str] = []
541-
saved_list_sizes: List[str] = []
542-
unpack: List[str] = []
543-
asserts: List[str] = []
544-
compute_index_ranges: List[str] = []
545-
getter_definitions: List[str] = []
546-
py_getsetdef_structs: List[str] = []
547-
compiled_args: List[str] = []
548-
apply_with_saved_before: List[str] = []
549-
apply_with_saved_after: List[str] = []
543+
saved_variables: list[str] = []
544+
release_variables: list[str] = []
545+
saved_list_sizes: list[str] = []
546+
unpack: list[str] = []
547+
asserts: list[str] = []
548+
compute_index_ranges: list[str] = []
549+
getter_definitions: list[str] = []
550+
py_getsetdef_structs: list[str] = []
551+
compiled_args: list[str] = []
552+
apply_with_saved_before: list[str] = []
553+
apply_with_saved_after: list[str] = []
550554

551555
for arg in info.args_with_derivatives:
552556
if arg.type in TENSOR_LIST_LIKE_CTYPES:
@@ -807,7 +811,7 @@ def save_var(var: SavedAttribute, is_output: bool) -> None:
807811
else:
808812
will_release_variables = ""
809813

810-
body: List[str] = []
814+
body: list[str] = []
811815

812816
if uses_single_grad(info):
813817
body.append("const auto& grad = grads[0];")
@@ -821,7 +825,7 @@ def save_var(var: SavedAttribute, is_output: bool) -> None:
821825
def emit_derivative(
822826
derivative: Derivative,
823827
args_with_derivatives: Sequence[Binding],
824-
) -> Tuple[bool, str]:
828+
) -> tuple[bool, str]:
825829
formula = derivative.formula
826830
var_names = derivative.var_names
827831
if len(var_names) == 1:
@@ -857,7 +861,7 @@ def emit_derivative(
857861
else:
858862
grad_input_mask = ""
859863
idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
860-
copy_ranges: List[str] = []
864+
copy_ranges: list[str] = []
861865
for i, n in enumerate(var_names):
862866
copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
863867
return False, DERIVATIVE_MULTI.substitute(
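
A quick way to see that hunks like `Tuple[bool, str]` to `tuple[bool, str]` are behavior-preserving: with postponed annotations the signature is stored verbatim as strings, never built into typing objects, and on Python 3.9+ `typing.get_type_hints` can still resolve builtin generics when introspection is needed. An illustrative check (`emit` is a stand-in, not the commit's `emit_derivative`):

from __future__ import annotations

import typing


def emit(flag: bool) -> tuple[bool, str]:
    return flag, "formula"


print(emit.__annotations__)         # {'flag': 'bool', 'return': 'tuple[bool, str]'}
print(typing.get_type_hints(emit))  # {'flag': <class 'bool'>, 'return': tuple[bool, str]}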
