Skip to content

Commit a8bed71

Browse files
voznesenskympytorchmergebot
authored andcommitted
[Easy] use BaseListVariable cls_for for all list-y type dispatching (pytorch#110159)
Pull Request resolved: pytorch#110159 Approved by: https://github.com/ezyang
1 parent ec5bbef commit a8bed71

File tree

2 files changed

+17
-17
lines changed

2 files changed

+17
-17
lines changed

torch/_dynamo/variables/builder.py

+1-16
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@
100100
from .higher_order_ops import TorchHigherOrderOperatorVariable
101101
from .lists import (
102102
BaseListVariable,
103-
DequeVariable,
104103
ListVariable,
105104
NamedTupleVariable,
106105
RangeVariable,
@@ -260,20 +259,6 @@ def _common_constants():
260259
# dynamic_shapes
261260
}
262261

263-
@staticmethod
264-
def list_type(value):
265-
if is_namedtuple(value):
266-
return functools.partial(NamedTupleVariable, tuple_cls=type(value))
267-
# TODO(voz): Why do we have both this and `BaseListVariable`'s `cls_for`?
268-
return {
269-
tuple: TupleVariable,
270-
list: ListVariable,
271-
odict_values: ListVariable,
272-
torch.nn.ParameterList: ListVariable,
273-
torch.nn.ModuleList: ListVariable,
274-
collections.deque: DequeVariable,
275-
}[type(value)]
276-
277262
def get_source(self):
278263
return self.source
279264

@@ -835,7 +820,7 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
835820
).add_guards(guards)
836821
for i, item in enumerate(value)
837822
]
838-
result = self.list_type(value)(
823+
result = BaseListVariable.cls_for_instance(value)(
839824
output, mutable_local=MutableLocal(), guards=guards
840825
)
841826
if istype(value, list):

torch/_dynamo/variables/lists.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
from ..exc import unimplemented
1414
from ..guards import make_dupe_guard
1515
from ..source import GetItemSource
16-
from ..utils import get_fake_value, guard_if_dyn, namedtuple_fields
16+
from ..utils import (
17+
get_fake_value,
18+
guard_if_dyn,
19+
is_namedtuple,
20+
namedtuple_fields,
21+
odict_values,
22+
)
1723
from .base import MutableLocal, VariableTracker
1824
from .constant import ConstantVariable
1925
from .functions import UserFunctionVariable, UserMethodVariable
@@ -43,6 +49,12 @@ def _listlike_contains_helper(items, search, tx, options):
4349

4450

4551
class BaseListVariable(VariableTracker):
52+
@staticmethod
53+
def cls_for_instance(obj):
54+
if is_namedtuple(obj):
55+
return functools.partial(NamedTupleVariable, tuple_cls=type(obj))
56+
return BaseListVariable.cls_for(type(obj))
57+
4658
@staticmethod
4759
def cls_for(obj):
4860
return {
@@ -52,6 +64,9 @@ def cls_for(obj):
5264
torch.Size: SizeVariable,
5365
tuple: TupleVariable,
5466
set: SetVariable,
67+
odict_values: ListVariable,
68+
torch.nn.ParameterList: ListVariable,
69+
torch.nn.ModuleList: ListVariable,
5570
collections.deque: DequeVariable,
5671
}[obj]
5772

0 commit comments

Comments
 (0)