
Commit f2cfe8b

aorenste authored and pytorchmergebot committed
PEP585 update - mostly toplevels (pytorch#145178)
See pytorch#145101 for details.

Pull Request resolved: pytorch#145178
Approved by: https://github.com/bobrenjc93
1 parent 1ce5338 commit f2cfe8b

39 files changed: +356 -386 lines
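The mechanical change applied across these files is the PEP 585 style: container annotations use the subscriptable built-ins (dict, list, set, tuple, type) instead of the deprecated typing aliases (Dict, List, Set, Tuple, Type), while Optional, Union, and Callable are left alone here, since switching those to the PEP 604 `X | Y` form is a separate change that needs Python 3.10+ at runtime. A minimal before/after sketch of the pattern, using made-up names rather than code from this PR:

    # Hypothetical example -- build_registry, lookup, and Registry are not PyTorch names.
    from typing import Optional, Union

    # Old style, now deprecated:
    #     from typing import Dict, List, Tuple
    #     def build_registry(names: List[str]) -> Dict[str, Tuple[int, ...]]: ...

    # New style (PEP 585, Python 3.9+): subscript the built-in types directly.
    Registry = dict[str, tuple[int, ...]]

    def build_registry(names: list[str]) -> Registry:
        """Map each name to the tuple of its character codes."""
        return {name: tuple(ord(c) for c in name) for name in names}

    def lookup(registry: Registry, key: str) -> Optional[tuple[int, ...]]:
        # Optional/Union still come from typing; `tuple[int, ...] | None` needs 3.10+.
        return registry.get(key)

    # `type` itself is subscriptable too, which is what replaces typing.Type below.
    numeric_types: set[type[Union[int, float]]] = {int, float}

    if __name__ == "__main__":
        reg = build_registry(["cudnn", "cublas"])
        print(lookup(reg, "cudnn"))  # (99, 117, 100, 110, 110)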

torch/_C/_cudnn.pyi (+3 -3)

@@ -1,12 +1,12 @@
 from enum import Enum

-from torch.types import _bool, Tuple
+from torch.types import _bool

 # Defined in torch/csrc/cuda/shared/cudnn.cpp
 is_cuda: _bool

-def getRuntimeVersion() -> Tuple[int, int, int]: ...
-def getCompileVersion() -> Tuple[int, int, int]: ...
+def getRuntimeVersion() -> tuple[int, int, int]: ...
+def getCompileVersion() -> tuple[int, int, int]: ...
 def getVersionInt() -> int: ...

 class RNNMode(int, Enum):
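A side note the diff itself does not spell out: .pyi stubs like torch/_C/_cudnn.pyi are only ever read by type checkers, never imported at runtime, so builtin generics in stubs carry no minimum-Python-version concern; all that matters is a checker that understands PEP 585. A tiny, hypothetical stub in the same style:

    # example.pyi (hypothetical stub, not part of PyTorch) -- annotations only, no runtime code
    def get_version() -> tuple[int, int, int]: ...
    def list_devices() -> list[str]: ...
    is_available: bool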

torch/__init__.py (+20 -24)

@@ -24,13 +24,9 @@
 from typing import (
     Any as _Any,
     Callable as _Callable,
-    Dict as _Dict,
     get_origin as _get_origin,
     Optional as _Optional,
     overload as _overload,
-    Set as _Set,
-    Tuple as _Tuple,
-    Type as _Type,
     TYPE_CHECKING,
     TypeVar as _TypeVar,
     Union as _Union,
@@ -337,7 +333,7 @@ def _load_global_deps() -> None:
     except OSError as err:
         # Can only happen for wheel with cuda libs as PYPI deps
         # As PyTorch is not purelib, but nvidia-*-cu12 is
-        cuda_libs: _Dict[str, str] = {
+        cuda_libs: dict[str, str] = {
             "cublas": "libcublas.so.*[0-9]",
             "cudnn": "libcudnn.so.*[0-9]",
             "cuda_nvrtc": "libnvrtc.so.*[0-9]",
@@ -586,7 +582,7 @@ def __hash__(self) -> builtins.int:
         # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
         # return hash(builtins.int(self))

-    def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
+    def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
         """Represent this int as an exact integer ratio"""
         return self, 1

@@ -698,7 +694,7 @@ def is_integer(self):
         """Return True if the float is an integer."""
         raise TypeError("type stub not overridden")

-    def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
+    def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
         """Represent this float as an exact integer ratio"""
         return builtins.float(self).as_integer_ratio()

@@ -857,22 +853,22 @@ def sym_max(a, b):
     assert isinstance(a, all_types), type(a)
     assert isinstance(b, all_types), type(b)
     if isinstance(a, float_types) or isinstance(b, float_types):
-        return builtins.float(builtins.max(a, b))
+        return builtins.float(builtins.max(a, b))  # type: ignore[call-overload]
     else:
-        return builtins.max(a, b)
+        return builtins.max(a, b)  # type: ignore[call-overload]


-def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
+def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
     try:
         import numpy as np

-        all_types: _Tuple[_Type, ...] = (
+        all_types: tuple[type, ...] = (
             np.integer,
             np.floating,
             builtins.int,
             builtins.float,
         )
-        float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
+        float_types: tuple[type, ...] = (np.floating, builtins.float)
     except ModuleNotFoundError:
         all_types = (builtins.int, builtins.float)
         float_types = (builtins.float,)
@@ -894,9 +890,9 @@ def sym_min(a, b):
     assert isinstance(a, all_types), type(a)
     assert isinstance(b, all_types), type(b)
     if isinstance(a, float_types) or isinstance(b, float_types):
-        return builtins.float(builtins.min(a, b))
+        return builtins.float(builtins.min(a, b))  # type: ignore[call-overload]
     else:
-        return builtins.min(a, b)
+        return builtins.min(a, b)  # type: ignore[call-overload]


 def sym_sum(args):
@@ -1204,7 +1200,7 @@ def set_default_device(
     _GLOBAL_DEVICE_CONTEXT.device_context = device_context


-def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
+def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None:
     r"""
     .. warning::

@@ -2007,7 +2003,7 @@ def _dtype(self):
         return torch.quint2x4


-_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
+_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = {
     UntypedStorage,
     DoubleStorage,
     FloatStorage,
@@ -2030,7 +2026,7 @@ def _dtype(self):
 }

 # The _tensor_classes set is initialized by the call to initialize_python_bindings.
-_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
+_tensor_classes: set[type["torch.Tensor"]] = set()

 # If you edit these imports, please update torch/__init__.py.in as well
 from torch import amp as amp, random as random, serialization as serialization
@@ -2282,7 +2278,7 @@ class _TorchCompileInductorWrapper:
     def __init__(self, mode, options, dynamic):
         from torch._inductor.compiler_bisector import CompilerBisector

-        self.config: _Dict[str, _Any] = {}
+        self.config: dict[str, _Any] = {}
         self.dynamic = dynamic
         self.apply_mode(mode)
         self.apply_options(options)
@@ -2309,13 +2305,13 @@ def apply_mode(self, mode: _Optional[str]):

         self.apply_options(list_mode_options(mode, self.dynamic))

-    def apply_options(self, options: _Optional[_Dict[str, _Any]]):
+    def apply_options(self, options: _Optional[dict[str, _Any]]):
         if not options:
             return

         from torch._inductor import config

-        current_config: _Dict[str, _Any] = config.get_config_copy()
+        current_config: dict[str, _Any] = config.get_config_copy()

         for key, val in options.items():
             attr_name = key.replace("-", "_")
@@ -2403,7 +2399,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Callable[_InputT, _RetT]: ...

@@ -2416,7 +2412,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...

@@ -2428,7 +2424,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Union[
     _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
@@ -2624,7 +2620,7 @@ def _register_device_module(device_type, module):

 class _TritonLibrary:
     lib = torch.library.Library("triton", "DEF")
-    ops_table: _Dict[_Tuple[str, str], _Callable] = {}
+    ops_table: dict[tuple[str, str], _Callable] = {}

     @classmethod
     def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):

torch/_custom_op/impl.py (+3 -3)

@@ -108,7 +108,7 @@ def inner(func):
 # An example usage is FakeTensor: FakeTensor checks if a specific operator
 # has an implementation registered via the CustomOp API.
 # Indexed by qualname (e.g. aten::foo)
-global_registry: typing.Dict[str, "CustomOp"] = {}
+global_registry: dict[str, "CustomOp"] = {}


 class CustomOp:
@@ -136,7 +136,7 @@ def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_acc
         self.__name__ = None  # mypy requires this
         # NB: Some of these impls are registered as kernels to DispatchKeys.
         # Modifying the _impls dict directly won't do anything in that case.
-        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
+        self._impls: dict[str, typing.Optional[FuncAndLocation]] = {}
         # See NOTE [CustomOp autograd kernel indirection]
         self._registered_autograd_kernel_indirection = False

@@ -476,7 +476,7 @@ def validate_schema(schema: FunctionSchema) -> None:
     )


-def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
+def parse_qualname(qualname: str) -> tuple[str, str]:
     names = qualname.split("::", 1)
     if len(names) != 2:
         raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "

torch/_dispatch/python.py (+1 -1)

@@ -1,8 +1,8 @@
 # mypy: allow-untyped-defs
 import itertools
 import unittest.mock
+from collections.abc import Iterator
 from contextlib import contextmanager
-from typing import Iterator

 import torch
 import torch._C
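The torch/_dispatch/python.py hunk is the other half of the same cleanup: PEP 585 also deprecates the typing aliases of the abstract collection types, so typing.Iterator gives way to collections.abc.Iterator, the class it was aliasing all along. A small sketch of the pattern; the functions below are invented, not taken from this file:

    from collections.abc import Iterator
    from contextlib import contextmanager

    # Before this change the import would have been: from typing import Iterator

    @contextmanager
    def enabled(flags: dict[str, bool], name: str) -> Iterator[None]:
        """Temporarily set a flag; @contextmanager generators are a typical reason
        a module that imports contextmanager also needs Iterator."""
        old = flags.get(name, False)
        flags[name] = True
        try:
            yield
        finally:
            flags[name] = old

    def evens(limit: int) -> Iterator[int]:
        """A plain generator annotated with the collections.abc protocol."""
        return (i for i in range(limit) if i % 2 == 0)

    if __name__ == "__main__":
        state: dict[str, bool] = {}
        with enabled(state, "python_dispatch"):
            print(state)        # {'python_dispatch': True}
        print(list(evens(6)))   # [0, 2, 4]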

torch/_guards.py (+16 -20)

@@ -17,13 +17,9 @@
 from typing import (
     Any,
     Callable,
-    Dict,
     Generic,
-    List,
     NamedTuple,
     Optional,
-    Set,
-    Tuple,
     TYPE_CHECKING,
     TypeVar,
     Union,
@@ -260,8 +256,8 @@ class Guard:
     create_fn: Callable[[GuardBuilderBase, Guard], None]

     # Export only. These values are written to at time of guard check_fn creation.
-    guard_types: Optional[List[str]] = None
-    code_list: Optional[List[str]] = None
+    guard_types: Optional[list[str]] = None
+    code_list: Optional[list[str]] = None
     obj_weakref: Optional[object] = None
     guarded_class_weakref: Optional[type] = None

@@ -448,8 +444,8 @@ def __post_init__(self):

 @dataclasses.dataclass
 class StorageOverlap(GuardEnvExpr):
-    overlapping_sources: List[Source]
-    non_overlapping_sources: List[Source]
+    overlapping_sources: list[Source]
+    non_overlapping_sources: list[Source]


 """
@@ -478,7 +474,7 @@ class GuardsCheckpointState:
     The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
     """

-    dynamo_guards: Set[Guard] = set()
+    dynamo_guards: set[Guard] = set()

     def __init__(self, dynamo_guards):
         self.dynamo_guards = dynamo_guards
@@ -500,7 +496,7 @@ def __eq__(self, other):


 class ModuleContextCheckpointState:
-    nn_modules: Dict[str, torch.nn.Module] = {}
+    nn_modules: dict[str, torch.nn.Module] = {}

     def __init__(self, nn_modules):
         self.nn_modules = nn_modules
@@ -523,7 +519,7 @@ def __eq__(self, other):

 class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
     def __init__(self) -> None:
-        self.nn_modules: Dict[str, Any] = {}
+        self.nn_modules: dict[str, Any] = {}

     def copy_graphstate(self):
         return ModuleContextCheckpointState(dict(self.nn_modules))
@@ -534,7 +530,7 @@ def restore_graphstate(self, state):


 class GlobalContextCheckpointState:
-    global_state: Dict[str, Tuple[Callable, ...]] = {}
+    global_state: dict[str, tuple[Callable, ...]] = {}

     def __init__(self, global_states):
         self.global_state = global_states
@@ -572,7 +568,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
     }

     def __init__(self) -> None:
-        self.global_state: Dict[str, Tuple[Callable, ...]] = {}
+        self.global_state: dict[str, tuple[Callable, ...]] = {}

     def copy_graphstate(self):
         return GlobalContextCheckpointState(dict(self.global_state))
@@ -628,7 +624,7 @@ def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
             guard.user_stack = TracingContext.extract_stack()
         self.inner.add(guard)

-    def update(self, *others: Set[Guard]):
+    def update(self, *others: set[Guard]):
         for o in others:
             for g in o:
                 self.add(g, skip=1)
@@ -641,7 +637,7 @@ def remove_guards_with_source(self, source):
 class GuardsContext(Checkpointable[GuardsCheckpointState]):
     def __init__(self) -> None:
         self.dynamo_guards: GuardsSet = GuardsSet()
-        self.aotautograd_guards: List[GuardEnvExpr] = []
+        self.aotautograd_guards: list[GuardEnvExpr] = []

     def copy_graphstate(self):
         return GuardsCheckpointState(set(self.dynamo_guards.inner))
@@ -674,9 +670,9 @@ def get_proxy_dispatch_entry(self, identifier: str): ...

 class InvokeSubgraphCache(HopSubgraphCache):
     def __init__(self) -> None:
-        self.autograd_cache: Dict[str, Callable] = {}
-        self.proxy_dispatch_cache: Dict[str, Callable] = {}
-        self.dynamo_identifiers: Dict[str, str] = {}
+        self.autograd_cache: dict[str, Callable] = {}
+        self.proxy_dispatch_cache: dict[str, Callable] = {}
+        self.dynamo_identifiers: dict[str, str] = {}

     def add_dynamo_identifier(self, cache_key: str, identifier: str):
         self.dynamo_identifiers[cache_key] = identifier
@@ -748,7 +744,7 @@ def __init__(self, compile_id):
         self.compile_id: Optional[CompileId] = compile_id
         self.attempt = 0
         # Verbose ShapeEnv guards produced.
-        self.shape_env_guards: List[str] = []
+        self.shape_env_guards: list[str] = []

     @staticmethod
     def current_compile_id():
@@ -816,7 +812,7 @@ def __init__(self, fake_mode):
         # careful not to accidentally induce guards on the SymInt if
         # you ever do change this in aot_autograd.py; you should check
         # on permutations preferentially.)
-        self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
+        self.output_strides: Optional[list[Optional[tuple[int, ...]]]] = None
         # When this is True, whenever we encounter an int in Dynamo tracing,
         # we will (1) force unspec it and (2) force it as a size-like unbacked
         # integer. This is currently used when processing certain lists of
