1- # mypy: allow-untyped-defs
21import collections
32import functools
4- from typing import Optional
3+ from typing import Any , Callable , Dict , Optional , Tuple , Union
4+ from typing_extensions import Self
55
66from .base import VariableTracker
77from .tensor import SymNodeVariable
1010class LazyCache :
1111 """Container to cache the real VariableTracker"""
1212
13- def __init__ (self , value , source ) -> None :
13+ def __init__ (self , value : Any , source : Any ) -> None :
1414 if not isinstance (value , LazySymNodeFormatString ):
1515 assert source
1616 self .value = value
1717 self .source = source
1818 self .vt : Optional [VariableTracker ] = None
1919
20- def realize (self ):
20+ def realize (self ) -> None :
2121 assert self .vt is None
2222 from ..symbolic_convert import InstructionTranslator
2323 from .builder import SourcelessBuilder , VariableBuilder
@@ -49,10 +49,10 @@ class LazyVariableTracker(VariableTracker):
4949 _nonvar_fields = {"_cache" , * VariableTracker ._nonvar_fields }
5050
5151 @staticmethod
52- def create (value , source , ** options ) :
52+ def create (value : Any , source : Any , ** options : Any ) -> "LazyVariableTracker" :
5353 return LazyVariableTracker (LazyCache (value , source ), source = source , ** options )
5454
55- def __init__ (self , _cache , ** kwargs ) -> None :
55+ def __init__ (self , _cache : LazyCache , ** kwargs : Any ) -> None :
5656 assert isinstance (_cache , LazyCache )
5757 super ().__init__ (** kwargs )
5858 self ._cache = _cache
@@ -64,16 +64,17 @@ def realize(self) -> VariableTracker:
6464 assert self ._cache .vt is not None
6565 return self ._cache .vt
6666
67- def unwrap (self ):
67+ def unwrap (self ) -> Union [ VariableTracker , Self ] :
6868 """Return the real VariableTracker if it already exists"""
6969 if self .is_realized ():
70+ assert self ._cache .vt is not None
7071 return self ._cache .vt
7172 return self
7273
73- def is_realized (self ):
74+ def is_realized (self ) -> bool :
7475 return self ._cache .vt is not None
7576
76- def clone (self , ** kwargs ) :
77+ def clone (self , ** kwargs : Any ) -> VariableTracker :
7778 assert kwargs .get ("_cache" , self ._cache ) is self ._cache
7879 if kwargs .get ("source" , self .source ) is not self .source :
7980 self .realize ()
@@ -84,7 +85,7 @@ def __str__(self) -> str:
8485 return self .unwrap ().__str__ ()
8586 return VariableTracker .__str__ (self .unwrap ())
8687
87- def __getattr__ (self , item ) :
88+ def __getattr__ (self , item : str ) -> Any :
8889 return getattr (self .realize (), item )
8990
9091 # most methods are auto-generated below, these are the ones we want to exclude
@@ -94,9 +95,9 @@ def __getattr__(self, item):
9495 @classmethod
9596 def realize_all (
9697 cls ,
97- value ,
98- cache = None ,
99- ):
98+ value : Any ,
99+ cache : Optional [ Dict [ int , Tuple [ Any , Any ]]] = None ,
100+ ) -> Any :
100101 """
101102 Walk an object and realize all LazyVariableTrackers inside it.
102103 """
@@ -150,15 +151,19 @@ def __str__(self) -> str:
150151 )
151152
152153
153- def _create_realize_and_forward (name ):
154+ def _create_realize_and_forward (
155+ name : str ,
156+ ) -> Callable [[LazyVariableTracker , Any , Any ], Any ]:
154157 @functools .wraps (getattr (VariableTracker , name ))
155- def realize_and_forward (self , * args , ** kwargs ):
158+ def realize_and_forward (
159+ self : LazyVariableTracker , * args : Any , ** kwargs : Any
160+ ) -> Any :
156161 return getattr (self .realize (), name )(* args , ** kwargs )
157162
158163 return realize_and_forward
159164
160165
161- def _populate ():
166+ def _populate () -> None :
162167 for name , value in VariableTracker .__dict__ .items ():
163168 if name not in LazyVariableTracker .__dict__ :
164169 if callable (value ):
0 commit comments