@@ -24,13 +24,9 @@
 from typing import (
     Any as _Any,
     Callable as _Callable,
-    Dict as _Dict,
     get_origin as _get_origin,
     Optional as _Optional,
     overload as _overload,
-    Set as _Set,
-    Tuple as _Tuple,
-    Type as _Type,
     TYPE_CHECKING,
     TypeVar as _TypeVar,
     Union as _Union,
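Note: all of the dropped aliases are covered by PEP 585, which made the builtin container types subscriptable from Python 3.9 onward, so the private `_Dict`/`_Set`/`_Tuple`/`_Type` names can be replaced by `dict`/`set`/`tuple`/`type` used directly. A minimal before/after sketch (the variables here are illustrative, not from the patch):

    # Pre-PEP 585: generic containers had to come from typing
    from typing import Dict
    libs: Dict[str, str] = {}

    # Python >= 3.9: the builtin type itself is subscriptable
    libs2: dict[str, str] = {}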
@@ -337,7 +333,7 @@ def _load_global_deps() -> None:
     except OSError as err:
         # Can only happen for wheel with cuda libs as PYPI deps
         # As PyTorch is not purelib, but nvidia-*-cu12 is
-        cuda_libs: _Dict[str, str] = {
+        cuda_libs: dict[str, str] = {
             "cublas": "libcublas.so.*[0-9]",
             "cudnn": "libcudnn.so.*[0-9]",
             "cuda_nvrtc": "libnvrtc.so.*[0-9]",
@@ -586,7 +582,7 @@ def __hash__(self) -> builtins.int:
         # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
         # return hash(builtins.int(self))

-    def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
+    def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
         """Represent this int as an exact integer ratio"""
         return self, 1

@@ -698,7 +694,7 @@ def is_integer(self):
         """Return True if the float is an integer."""
         raise TypeError("type stub not overridden")

-    def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
+    def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
         """Represent this float as an exact integer ratio"""
         return builtins.float(self).as_integer_ratio()

@@ -857,22 +853,22 @@ def sym_max(a, b):
     assert isinstance(a, all_types), type(a)
     assert isinstance(b, all_types), type(b)
     if isinstance(a, float_types) or isinstance(b, float_types):
-        return builtins.float(builtins.max(a, b))
+        return builtins.float(builtins.max(a, b))  # type: ignore[call-overload]
     else:
-        return builtins.max(a, b)
+        return builtins.max(a, b)  # type: ignore[call-overload]


-def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
+def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
     try:
         import numpy as np

-        all_types: _Tuple[_Type, ...] = (
+        all_types: tuple[type, ...] = (
             np.integer,
             np.floating,
             builtins.int,
             builtins.float,
         )
-        float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
+        float_types: tuple[type, ...] = (np.floating, builtins.float)
     except ModuleNotFoundError:
         all_types = (builtins.int, builtins.float)
         float_types = (builtins.float,)
@@ -894,9 +890,9 @@ def sym_min(a, b):
     assert isinstance(a, all_types), type(a)
     assert isinstance(b, all_types), type(b)
     if isinstance(a, float_types) or isinstance(b, float_types):
-        return builtins.float(builtins.min(a, b))
+        return builtins.float(builtins.min(a, b))  # type: ignore[call-overload]
     else:
-        return builtins.min(a, b)
+        return builtins.min(a, b)  # type: ignore[call-overload]


 def sym_sum(args):
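The new `# type: ignore[call-overload]` comments are presumably needed because `tuple[type, ...]` gives mypy nothing to narrow on in the preceding `isinstance` checks, so the `builtins.max`/`builtins.min` calls no longer match any overload of `max`/`min`. A reduced illustration of the error class being suppressed (assuming standard mypy behavior; the function is illustrative):

    import builtins

    def pick(a: object, b: object) -> object:
        # mypy rejects this call: no overload of max() accepts two
        # plain objects, even though the runtime values are numbers.
        return builtins.max(a, b)  # type: ignore[call-overload]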
@@ -1204,7 +1200,7 @@ def set_default_device(
     _GLOBAL_DEVICE_CONTEXT.device_context = device_context


-def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
+def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None:
     r"""
     .. warning::

@@ -2007,7 +2003,7 @@ def _dtype(self):
         return torch.quint2x4


-_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
+_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = {
     UntypedStorage,
     DoubleStorage,
     FloatStorage,
@@ -2030,7 +2026,7 @@ def _dtype(self):
 }

 # The _tensor_classes set is initialized by the call to initialize_python_bindings.
-_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
+_tensor_classes: set[type["torch.Tensor"]] = set()

 # If you edit these imports, please update torch/__init__.py.in as well
 from torch import amp as amp, random as random, serialization as serialization
@@ -2282,7 +2278,7 @@ class _TorchCompileInductorWrapper:
     def __init__(self, mode, options, dynamic):
         from torch._inductor.compiler_bisector import CompilerBisector

-        self.config: _Dict[str, _Any] = {}
+        self.config: dict[str, _Any] = {}
         self.dynamic = dynamic
         self.apply_mode(mode)
         self.apply_options(options)
@@ -2309,13 +2305,13 @@ def apply_mode(self, mode: _Optional[str]):

         self.apply_options(list_mode_options(mode, self.dynamic))

-    def apply_options(self, options: _Optional[_Dict[str, _Any]]):
+    def apply_options(self, options: _Optional[dict[str, _Any]]):
         if not options:
             return

         from torch._inductor import config

-        current_config: _Dict[str, _Any] = config.get_config_copy()
+        current_config: dict[str, _Any] = config.get_config_copy()

         for key, val in options.items():
             attr_name = key.replace("-", "_")
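Note that `apply_options` normalizes keys with `key.replace("-", "_")` before matching them against the inductor config, so callers may spell config knobs with dashes. A hedged usage sketch (`epilogue_fusion` is an inductor option named in the `torch.compile` docs; the function being compiled is illustrative):

    import torch

    def fn(x):
        return x.sin() + x.cos()

    # "epilogue-fusion" is normalized to "epilogue_fusion" before lookup
    compiled = torch.compile(fn, options={"epilogue-fusion": True})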
@@ -2403,7 +2399,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Callable[_InputT, _RetT]: ...

@@ -2416,7 +2412,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...

@@ -2428,7 +2424,7 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
     disable: builtins.bool = False,
 ) -> _Union[
     _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
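These three `@overload` stubs give `torch.compile` a precise type for each calling convention: passing a callable returns the compiled callable, passing `model=None` returns a decorator, and the implementation signature is the union of both. A minimal standalone sketch of the same pattern (simplified signatures, not the actual stubs):

    from typing import Callable, Optional, Union, overload

    @overload
    def compile(model: Callable) -> Callable: ...
    @overload
    def compile(model: None = None) -> Callable[[Callable], Callable]: ...

    def compile(
        model: Optional[Callable] = None,
    ) -> Union[Callable, Callable[[Callable], Callable]]:
        if model is None:
            # Used as a decorator factory: return a wrapper that compiles later.
            return lambda fn: compile(fn)
        return model  # a real implementation would wrap `model` here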
@@ -2624,7 +2620,7 @@ def _register_device_module(device_type, module):

 class _TritonLibrary:
     lib = torch.library.Library("triton", "DEF")
-    ops_table: _Dict[_Tuple[str, str], _Callable] = {}
+    ops_table: dict[tuple[str, str], _Callable] = {}

     @classmethod
     def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):