1
- # mypy: allow-untyped-defs
2
1
import collections
3
2
import functools
4
- from typing import Optional
3
+ from typing import Any , Callable , Dict , Optional , Tuple , Union
4
+ from typing_extensions import Self
5
5
6
6
from .base import VariableTracker
7
7
from .tensor import SymNodeVariable
10
10
class LazyCache :
11
11
"""Container to cache the real VariableTracker"""
12
12
13
- def __init__ (self , value , source ) -> None :
13
+ def __init__ (self , value : Any , source : Any ) -> None :
14
14
if not isinstance (value , LazySymNodeFormatString ):
15
15
assert source
16
16
self .value = value
17
17
self .source = source
18
18
self .vt : Optional [VariableTracker ] = None
19
19
20
- def realize (self ):
20
+ def realize (self ) -> None :
21
21
assert self .vt is None
22
22
from ..symbolic_convert import InstructionTranslator
23
23
from .builder import SourcelessBuilder , VariableBuilder
@@ -49,10 +49,10 @@ class LazyVariableTracker(VariableTracker):
49
49
_nonvar_fields = {"_cache" , * VariableTracker ._nonvar_fields }
50
50
51
51
@staticmethod
52
- def create (value , source , ** options ) :
52
+ def create (value : Any , source : Any , ** options : Any ) -> "LazyVariableTracker" :
53
53
return LazyVariableTracker (LazyCache (value , source ), source = source , ** options )
54
54
55
- def __init__ (self , _cache , ** kwargs ) -> None :
55
+ def __init__ (self , _cache : LazyCache , ** kwargs : Any ) -> None :
56
56
assert isinstance (_cache , LazyCache )
57
57
super ().__init__ (** kwargs )
58
58
self ._cache = _cache
@@ -64,16 +64,17 @@ def realize(self) -> VariableTracker:
64
64
assert self ._cache .vt is not None
65
65
return self ._cache .vt
66
66
67
- def unwrap (self ):
67
+ def unwrap (self ) -> Union [ VariableTracker , Self ] :
68
68
"""Return the real VariableTracker if it already exists"""
69
69
if self .is_realized ():
70
+ assert self ._cache .vt is not None
70
71
return self ._cache .vt
71
72
return self
72
73
73
- def is_realized (self ):
74
+ def is_realized (self ) -> bool :
74
75
return self ._cache .vt is not None
75
76
76
- def clone (self , ** kwargs ) :
77
+ def clone (self , ** kwargs : Any ) -> VariableTracker :
77
78
assert kwargs .get ("_cache" , self ._cache ) is self ._cache
78
79
if kwargs .get ("source" , self .source ) is not self .source :
79
80
self .realize ()
@@ -84,7 +85,7 @@ def __str__(self) -> str:
84
85
return self .unwrap ().__str__ ()
85
86
return VariableTracker .__str__ (self .unwrap ())
86
87
87
- def __getattr__ (self , item ) :
88
+ def __getattr__ (self , item : str ) -> Any :
88
89
return getattr (self .realize (), item )
89
90
90
91
# most methods are auto-generated below, these are the ones we want to exclude
@@ -94,9 +95,9 @@ def __getattr__(self, item):
94
95
@classmethod
95
96
def realize_all (
96
97
cls ,
97
- value ,
98
- cache = None ,
99
- ):
98
+ value : Any ,
99
+ cache : Optional [ Dict [ int , Tuple [ Any , Any ]]] = None ,
100
+ ) -> Any :
100
101
"""
101
102
Walk an object and realize all LazyVariableTrackers inside it.
102
103
"""
@@ -150,15 +151,19 @@ def __str__(self) -> str:
150
151
)
151
152
152
153
153
- def _create_realize_and_forward (name ):
154
+ def _create_realize_and_forward (
155
+ name : str ,
156
+ ) -> Callable [[LazyVariableTracker , Any , Any ], Any ]:
154
157
@functools .wraps (getattr (VariableTracker , name ))
155
- def realize_and_forward (self , * args , ** kwargs ):
158
+ def realize_and_forward (
159
+ self : LazyVariableTracker , * args : Any , ** kwargs : Any
160
+ ) -> Any :
156
161
return getattr (self .realize (), name )(* args , ** kwargs )
157
162
158
163
return realize_and_forward
159
164
160
165
161
- def _populate ():
166
+ def _populate () -> None :
162
167
for name , value in VariableTracker .__dict__ .items ():
163
168
if name not in LazyVariableTracker .__dict__ :
164
169
if callable (value ):
0 commit comments