Skip to content

Commit b4f4c75

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Support multiple inheritance for custom dict construction (pytorch#142416)
This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created pytorch#142414 to track that. Fixes pytorch#141118. Pull Request resolved: pytorch#142416 Approved by: https://github.com/jansel
1 parent b5d8d24 commit b4f4c75

File tree

4 files changed

+75
-41
lines changed

4 files changed

+75
-41
lines changed

test/dynamo/test_misc.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3015,6 +3015,45 @@ def fn(d, x):
30153015
self.assertEqual(fn(args2, x), opt_fn(args2, x))
30163016
self.assertEqual(cnts.frame_count, 2)
30173017

3018+
def test_mutable_mapping_multiple_inheritance(self):
3019+
class MyWeirdDict(collections.abc.MutableMapping, torch.nn.Module):
3020+
def __init__(self, **kwargs):
3021+
super().__init__()
3022+
self._items = kwargs
3023+
3024+
def keys(self):
3025+
return self._items.keys()
3026+
3027+
def __getitem__(self, item):
3028+
return self._items[item]
3029+
3030+
def __setitem__(self, key, value):
3031+
self._items[key] = value
3032+
3033+
def __delitem__(self, item):
3034+
del self._items[item]
3035+
3036+
def __len__(self):
3037+
return len(self._items)
3038+
3039+
def __iter__(self):
3040+
yield from self._items
3041+
3042+
def __hash__(self):
3043+
return hash(id(self))
3044+
3045+
def items(self):
3046+
for k, v in self._items.items():
3047+
yield (k, v)
3048+
3049+
@torch.compile(fullgraph=True)
3050+
def to_weird_dict(td):
3051+
return MyWeirdDict(**td)
3052+
3053+
d = MyWeirdDict(a=1, b=2, c=3)
3054+
res = to_weird_dict(d)
3055+
self.assertEqual(tuple(d.items()), tuple(res.items()))
3056+
30183057
def test_dunder_new_function_inlining(self):
30193058
# https://github.com/pytorch/pytorch/issues/107460
30203059

torch/_dynamo/variables/builtin.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import operator
1010
import types
1111
from collections import defaultdict, OrderedDict
12-
from collections.abc import KeysView
12+
from collections.abc import KeysView, MutableMapping
1313
from typing import Dict, List, TYPE_CHECKING
1414

1515
import torch
@@ -1394,18 +1394,28 @@ def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
13941394
return ConstDictVariable(
13951395
items, user_cls, mutation_type=ValueMutationNew()
13961396
)
1397-
elif isinstance(arg, variables.MutableMappingVariable):
1398-
# This is applicable for user defined objects which seem like dict, but are not really dicts. For
1399-
# example, TensorDict derives from MutableMapping. For such cases, we can directly inline the .items
1400-
# method and create a new dict.
1397+
elif hasattr(arg, "value") and isinstance(arg.value, MutableMapping):
1398+
# This handles all other `MutableMapping` instances; for
1399+
# example, TensorDict which derives from MutableMapping.
1400+
#
1401+
# TODO(#142414) `hasattr(arg, 'value')` is a local workaround
1402+
# for lack of generall multiple inheritance in Dynamo. We can't
1403+
# use `isinstance(arg, MutableMappingVariable)` here because
1404+
# `arg` could be, e.g., a `UnspecializedNNModuleVariable` when
1405+
# `arg.value` has multiple inheritace.
14011406
if does_not_override_dict_iter_methods(type(arg.value)):
1402-
# These are implemeted in C, so we will have to manually construct the items
1403-
1407+
# In this case, `arg.value.items()` uses the default impls,
1408+
# which are implemented in C and cannot be traced, so we
1409+
# will have to manually construct the items. This is safe
1410+
# because we know they are side-effect free.
1411+
#
1412+
# Mutation tracked by Dynamo isn't reflected in `arg.value`,
1413+
# so we can't handle such cases by just calling
1414+
# `arg.value.items()`
14041415
if tx.output.side_effects.has_pending_mutation(arg):
14051416
unimplemented(
14061417
f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated"
14071418
)
1408-
14091419
new_dict = dict(arg.value.items())
14101420
return VariableTracker.build(tx, new_dict)
14111421
else:

torch/_dynamo/variables/dicts.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def underlying_value(self):
8989
Hashable = ConstDictVariable._HashableTracker
9090
x = tuple(Hashable(e).underlying_value for e in self.vt.items)
9191
elif isinstance(self.vt, variables.NNModuleVariable):
92-
return self.vt.module
92+
return self.vt.value
9393
elif isinstance(self.vt, variables.UnspecializedNNModuleVariable):
9494
return self.vt.value
9595
elif isinstance(self.vt, variables.UserFunctionVariable):
@@ -277,14 +277,7 @@ def call_method(
277277
args: "List[VariableTracker]",
278278
kwargs: "Dict[str, VariableTracker]",
279279
) -> "VariableTracker":
280-
from . import (
281-
BuiltinVariable,
282-
ConstantVariable,
283-
ListIteratorVariable,
284-
ListVariable,
285-
TupleVariable,
286-
UserDefinedObjectVariable,
287-
)
280+
from . import BuiltinVariable, ConstantVariable, TupleVariable
288281

289282
Hashable = ConstDictVariable._HashableTracker
290283

@@ -344,33 +337,25 @@ def call_method(
344337
self.items.clear()
345338
return ConstantVariable.create(None)
346339
elif name == "update" and self.is_mutable():
347-
is_args_supported = len(args) == 1 and isinstance(
348-
args[0],
349-
(
350-
ConstDictVariable,
351-
ListVariable,
352-
TupleVariable,
353-
ListIteratorVariable,
354-
variables.IteratorVariable,
355-
UserDefinedObjectVariable,
356-
),
357-
)
358-
359-
is_kwargs_supported = len(kwargs) > 0 and len(args) == 0
360-
361-
if is_args_supported or is_kwargs_supported:
340+
# In general, this call looks like `a.update(b, x=1, y=2, ...)`.
341+
# Either `b` or the kwargs is omittable, but not both.
342+
has_arg = len(args) == 1
343+
has_kwargs = len(kwargs) > 0
344+
if has_arg or has_kwargs:
362345
tx.output.side_effects.mutation(self)
363-
if len(args) == 1:
346+
if has_arg:
364347
if isinstance(args[0], ConstDictVariable):
365348
dict_vt = args[0]
366349
else:
367350
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
368351
self.items.update(dict_vt.items)
369-
# Wrap strings
370-
kwargs = {
371-
Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
372-
}
373-
self.items.update(kwargs)
352+
if has_kwargs:
353+
# Handle kwargs
354+
kwargs = {
355+
Hashable(ConstantVariable.create(k)): v
356+
for k, v in kwargs.items()
357+
}
358+
self.items.update(kwargs)
374359
return ConstantVariable.create(None)
375360
else:
376361
return super().call_method(tx, name, args, kwargs)

torch/_dynamo/variables/nn_module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,18 +131,18 @@ class NNModuleVariable(VariableTracker):
131131
_nonvar_fields = {
132132
"module_type",
133133
"module_key",
134-
"module",
134+
"value",
135135
"nn_module_stack_source",
136136
*VariableTracker._nonvar_fields,
137137
}
138138

139139
def __init__(
140-
self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs
140+
self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs
141141
) -> None:
142142
super().__init__(**kwargs)
143143
self.module_type = module_type
144144
self.module_key = module_key
145-
self.module = module
145+
self.value = value
146146
assert self.source
147147
self.nn_module_stack_source = self.source
148148

0 commit comments

Comments
 (0)