Skip to content

Commit 9731ccb

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
Type _dynamo/variables/lazy.py (pytorch#136376)
Pull Request resolved: pytorch#136376 Approved by: https://github.com/Skylion007
1 parent 0971563 commit 9731ccb

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

Diff for: torch/_dynamo/variables/lazy.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# mypy: allow-untyped-defs
21
import collections
32
import functools
4-
from typing import Optional
3+
from typing import Any, Callable, Dict, Optional, Tuple, Union
4+
from typing_extensions import Self
55

66
from .base import VariableTracker
77
from .tensor import SymNodeVariable
@@ -10,14 +10,14 @@
1010
class LazyCache:
1111
"""Container to cache the real VariableTracker"""
1212

13-
def __init__(self, value, source) -> None:
13+
def __init__(self, value: Any, source: Any) -> None:
1414
if not isinstance(value, LazySymNodeFormatString):
1515
assert source
1616
self.value = value
1717
self.source = source
1818
self.vt: Optional[VariableTracker] = None
1919

20-
def realize(self):
20+
def realize(self) -> None:
2121
assert self.vt is None
2222
from ..symbolic_convert import InstructionTranslator
2323
from .builder import SourcelessBuilder, VariableBuilder
@@ -49,10 +49,10 @@ class LazyVariableTracker(VariableTracker):
4949
_nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}
5050

5151
@staticmethod
52-
def create(value, source, **options):
52+
def create(value: Any, source: Any, **options: Any) -> "LazyVariableTracker":
5353
return LazyVariableTracker(LazyCache(value, source), source=source, **options)
5454

55-
def __init__(self, _cache, **kwargs) -> None:
55+
def __init__(self, _cache: LazyCache, **kwargs: Any) -> None:
5656
assert isinstance(_cache, LazyCache)
5757
super().__init__(**kwargs)
5858
self._cache = _cache
@@ -64,16 +64,17 @@ def realize(self) -> VariableTracker:
6464
assert self._cache.vt is not None
6565
return self._cache.vt
6666

67-
def unwrap(self):
67+
def unwrap(self) -> Union[VariableTracker, Self]:
6868
"""Return the real VariableTracker if it already exists"""
6969
if self.is_realized():
70+
assert self._cache.vt is not None
7071
return self._cache.vt
7172
return self
7273

73-
def is_realized(self):
74+
def is_realized(self) -> bool:
7475
return self._cache.vt is not None
7576

76-
def clone(self, **kwargs):
77+
def clone(self, **kwargs: Any) -> VariableTracker:
7778
assert kwargs.get("_cache", self._cache) is self._cache
7879
if kwargs.get("source", self.source) is not self.source:
7980
self.realize()
@@ -84,7 +85,7 @@ def __str__(self) -> str:
8485
return self.unwrap().__str__()
8586
return VariableTracker.__str__(self.unwrap())
8687

87-
def __getattr__(self, item):
88+
def __getattr__(self, item: str) -> Any:
8889
return getattr(self.realize(), item)
8990

9091
# most methods are auto-generated below, these are the ones we want to exclude
@@ -94,9 +95,9 @@ def __getattr__(self, item):
9495
@classmethod
9596
def realize_all(
9697
cls,
97-
value,
98-
cache=None,
99-
):
98+
value: Any,
99+
cache: Optional[Dict[int, Tuple[Any, Any]]] = None,
100+
) -> Any:
100101
"""
101102
Walk an object and realize all LazyVariableTrackers inside it.
102103
"""
@@ -150,15 +151,19 @@ def __str__(self) -> str:
150151
)
151152

152153

153-
def _create_realize_and_forward(name):
154+
def _create_realize_and_forward(
155+
name: str,
156+
) -> Callable[[LazyVariableTracker, Any, Any], Any]:
154157
@functools.wraps(getattr(VariableTracker, name))
155-
def realize_and_forward(self, *args, **kwargs):
158+
def realize_and_forward(
159+
self: LazyVariableTracker, *args: Any, **kwargs: Any
160+
) -> Any:
156161
return getattr(self.realize(), name)(*args, **kwargs)
157162

158163
return realize_and_forward
159164

160165

161-
def _populate():
166+
def _populate() -> None:
162167
for name, value in VariableTracker.__dict__.items():
163168
if name not in LazyVariableTracker.__dict__:
164169
if callable(value):

0 commit comments

Comments
 (0)