@@ -24,13 +24,9 @@
 from typing import (
     Any as _Any,
     Callable as _Callable,
-    Dict as _Dict,
     get_origin as _get_origin,
     Optional as _Optional,
     overload as _overload,
-    Set as _Set,
-    Tuple as _Tuple,
-    Type as _Type,
     TYPE_CHECKING,
     TypeVar as _TypeVar,
     Union as _Union,
@@ -337,7 +333,7 @@ def _load_global_deps() -> None:
     except OSError as err:
         # Can only happen for wheel with cuda libs as PYPI deps
         # As PyTorch is not purelib, but nvidia-*-cu12 is
-        cuda_libs: _Dict[str, str] = {
+        cuda_libs: dict[str, str] = {
             "cublas": "libcublas.so.*[0-9]",
             "cudnn": "libcudnn.so.*[0-9]",
             "cuda_nvrtc": "libnvrtc.so.*[0-9]",
@@ -586,7 +582,7 @@ def __hash__(self) -> builtins.int:
         # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
         # return hash(builtins.int(self))

-    def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
+    def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
         """Represent this int as an exact integer ratio"""
         return self, 1

@@ -698,7 +694,7 @@ def is_integer(self):
         """Return True if the float is an integer."""
         raise TypeError("type stub not overridden")

-    def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
+    def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
         """Represent this float as an exact integer ratio"""
         return builtins.float(self).as_integer_ratio()

@@ -857,22 +853,22 @@ def sym_max(a, b):
     assert isinstance(a, all_types), type(a)
     assert isinstance(b, all_types), type(b)
     if isinstance(a, float_types) or isinstance(b, float_types):
-        return builtins.float(builtins.max(a, b))
+        return builtins.float(builtins.max(a, b))  # type: ignore[call-overload]
     else:
-        return builtins.max(a, b)
+        return builtins.max(a, b)  # type: ignore[call-overload]


-def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
+def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
     try:
         import numpy as np

-        all_types: _Tuple[_Type, ...] = (
+        all_types: tuple[type, ...] = (
             np.integer,
             np.floating,
             builtins.int,
             builtins.float,
         )
-        float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
+        float_types: tuple[type, ...] = (np.floating, builtins.float)
     except ModuleNotFoundError:
         all_types = (builtins.int, builtins.float)
         float_types = (builtins.float,)
@@ -894,9 +890,9 @@ def sym_min(a, b):
     assert isinstance(a, all_types), type(a)
     assert isinstance(b, all_types), type(b)
     if isinstance(a, float_types) or isinstance(b, float_types):
-        return builtins.float(builtins.min(a, b))
+        return builtins.float(builtins.min(a, b))  # type: ignore[call-overload]
     else:
-        return builtins.min(a, b)
+        return builtins.min(a, b)  # type: ignore[call-overload]


 def sym_sum(args):
@@ -1204,7 +1200,7 @@ def set_default_device(
     _GLOBAL_DEVICE_CONTEXT.device_context = device_context


-def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
+def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None:
     r"""
     .. warning::

@@ -2007,7 +2003,7 @@ def _dtype(self):
         return torch.quint2x4


-_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
+_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = {
     UntypedStorage,
     DoubleStorage,
     FloatStorage,
@@ -2030,7 +2026,7 @@ def _dtype(self):
 }

 # The _tensor_classes set is initialized by the call to initialize_python_bindings.
-_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
+_tensor_classes: set[type["torch.Tensor"]] = set()

 # If you edit these imports, please update torch/__init__.py.in as well
 from torch import amp as amp, random as random, serialization as serialization
@@ -2282,7 +2278,7 @@ class _TorchCompileInductorWrapper:
     def __init__(self, mode, options, dynamic):
         from torch._inductor.compiler_bisector import CompilerBisector

-        self.config: _Dict[str, _Any] = {}
+        self.config: dict[str, _Any] = {}
         self.dynamic = dynamic
         self.apply_mode(mode)
         self.apply_options(options)
@@ -2309,13 +2305,13 @@ def apply_mode(self, mode: _Optional[str]):

             self.apply_options(list_mode_options(mode, self.dynamic))

-    def apply_options(self, options: _Optional[_Dict[str, _Any]]):
+    def apply_options(self, options: _Optional[dict[str, _Any]]):
         if not options:
             return

         from torch._inductor import config

-        current_config: _Dict[str, _Any] = config.get_config_copy()
+        current_config: dict[str, _Any] = config.get_config_copy()

         for key, val in options.items():
             attr_name = key.replace("-", "_")
@@ -2403,7 +2399,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Callable[_InputT, _RetT]: ...

@@ -2416,7 +2412,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...

@@ -2428,7 +2424,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Union[
     _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
@@ -2624,7 +2620,7 @@ def _register_device_module(device_type, module):

 class _TritonLibrary:
     lib = torch.library.Library("triton", "DEF")
-    ops_table: _Dict[_Tuple[str, str], _Callable] = {}
+    ops_table: dict[tuple[str, str], _Callable] = {}

     @classmethod
     def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
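
A minimal sketch (not part of the commit; the standalone names are illustrative) of the pattern this diff applies throughout `torch/__init__.py`: `typing.Dict`, `Tuple`, `Set`, and `Type` are dropped in favor of the PEP 585 builtin generics, available since Python 3.9, while the `_Optional`, `_Union`, `_Any`, and `_Callable` aliases stay because their `X | Y` replacements require Python 3.10.

```python
from typing import Any, Optional, Union, get_origin

# PEP 585 (Python 3.9+): builtins are subscriptable, so the typing aliases
# Dict/Tuple/Set/Type are redundant. These mirror annotations in the diff.
cuda_libs: dict[str, str] = {"cublas": "libcublas.so.*[0-9]"}
config: dict[str, Any] = {}
ops_table: dict[tuple[str, str], str] = {}
tensor_classes: set[type] = set()

# Optional/Union keep their typing spellings here: the `X | Y` / `X | None`
# syntax only arrived in Python 3.10.
options: Optional[dict[str, Union[str, int, bool]]] = None

def as_integer_ratio(x: float) -> tuple[int, int]:
    """Same return-annotation style as SymFloat.as_integer_ratio in the diff."""
    return x.as_integer_ratio()

# Both spellings produce the same runtime object under introspection:
assert get_origin(dict[str, str]) is dict
```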
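The only non-mechanical change is the pair of `# type: ignore[call-overload]` comments added in `sym_max`/`sym_min`; they silence mypy's call-overload error on those `builtins.max`/`builtins.min` calls and have no runtime effect. A standalone sketch of the same shape, with the hypothetical name `sym_max_like` and the operand set cut down to builtins (the real function also accepts `SymInt`/`SymFloat` and numpy scalars):

```python
import builtins

# As in __all_and_float_types(), the accepted types are a runtime-built tuple.
all_types: tuple[type, ...] = (builtins.int, builtins.float)
float_types: tuple[type, ...] = (builtins.float,)

def sym_max_like(a, b):
    assert isinstance(a, all_types), type(a)
    assert isinstance(b, all_types), type(b)
    if isinstance(a, float_types) or isinstance(b, float_types):
        # Promote to float whenever either operand is floating-point.
        return builtins.float(builtins.max(a, b))  # type: ignore[call-overload]
    return builtins.max(a, b)  # type: ignore[call-overload]

print(sym_max_like(2, 3.5))  # -> 3.5; two int inputs stay int
```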