Skip to content

Commit 9bbe4a6

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[dynamo] support maxlen for collections.deque (pytorch#138194)
Pull Request resolved: pytorch#138194 Approved by: https://github.com/jansel, https://github.com/malfet
1 parent a4b3576 commit 9bbe4a6

File tree

4 files changed

+69
-24
lines changed

4 files changed

+69
-24
lines changed

test/dynamo/test_misc.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9963,7 +9963,7 @@ def fn(x):
99639963
"c": (
99649964
x,
99659965
3.0,
9966-
collections.deque([0.0, -x]),
9966+
collections.deque([0.0, -x, 1, 2], maxlen=3),
99679967
),
99689968
"d": collections.OrderedDict(
99699969
{
@@ -9995,7 +9995,7 @@ def fn(x, y):
99959995
"c": (
99969996
x,
99979997
3.0,
9998-
[0.0, -x],
9998+
collections.deque([0.0, -x, 1, 2], maxlen=3),
99999999
),
1000010000
"d": collections.OrderedDict(
1000110001
{
@@ -10011,6 +10011,7 @@ def fn(x, y):
1001110011
x * y,
1001210012
3.0,
1001310013
y - 2,
10014+
1,
1001410015
torch.zeros(2, 2),
1001510016
2 * y,
1001610017
-y,
@@ -10043,7 +10044,7 @@ def fn(x, y):
1004310044
"c": (
1004410045
x,
1004510046
3.0,
10046-
[0.0, -x],
10047+
collections.deque([0.0, -x, 1, 2], maxlen=3),
1004710048
),
1004810049
"d": collections.OrderedDict(
1004910050
{
@@ -10054,7 +10055,7 @@ def fn(x, y):
1005410055
}
1005510056
tree2 = collections.OrderedDict(
1005610057
[
10057-
("c", (y, 3.0, [-y, 10.0])),
10058+
("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
1005810059
("a", [y, y + 1]),
1005910060
("b", y + 2),
1006010061
(

torch/_dynamo/variables/lists.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -452,12 +452,33 @@ def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTrack
452452

453453

454454
class DequeVariable(CommonListMethodsVariable):
455+
def __init__(self, items, maxlen=None, **kwargs) -> None:
456+
if maxlen is None:
457+
maxlen = ConstantVariable.create(None)
458+
assert (
459+
maxlen.is_python_constant()
460+
), f"maxlen must be a constant, got: {maxlen.debug_repr()}"
461+
self.maxlen = maxlen
462+
if self.maxlen.as_python_constant() is not None:
463+
items = list(items)[-maxlen.as_python_constant() :]
464+
super().__init__(items, **kwargs)
465+
455466
def python_type(self):
456467
return collections.deque
457468

458469
def debug_repr(self):
470+
if self.maxlen.as_python_constant() is None:
471+
return self.debug_repr_helper(
472+
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
473+
)
459474
return self.debug_repr_helper("deque([", "])")
460475

476+
def as_python_constant(self):
477+
return self.python_type()(
478+
[x.as_python_constant() for x in self.items],
479+
maxlen=self.maxlen.as_python_constant(),
480+
)
481+
461482
def reconstruct(self, codegen: "PyCodegen") -> None:
462483
assert "deque" not in codegen.tx.f_globals
463484
codegen.add_push_null(
@@ -466,12 +487,14 @@ def reconstruct(self, codegen: "PyCodegen") -> None:
466487
)
467488
)
468489
codegen.foreach(self.items)
469-
codegen.extend_output(
470-
[
471-
create_instruction("BUILD_LIST", arg=len(self.items)),
472-
*create_call_function(1, False),
473-
]
474-
)
490+
codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))])
491+
codegen(self.maxlen)
492+
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
493+
494+
def var_getattr(self, tx: "InstructionTranslator", name):
495+
if name == "maxlen":
496+
return self.maxlen
497+
return super().var_getattr(tx, name)
475498

476499
def call_method(
477500
self,
@@ -494,33 +517,37 @@ def call_method(
494517
tx.output.side_effects.mutation(self)
495518
self.items[key.as_python_constant()] = value
496519
return ConstantVariable.create(None)
497-
elif (
520+
521+
if (
498522
name == "extendleft"
499523
and self.mutable_local
500524
and args[0].has_force_unpack_var_sequence(tx)
501525
):
502526
assert not kwargs
503-
504527
(arg,) = args
505528
prefix = arg.force_unpack_var_sequence(tx)
506529
prefix.reverse()
507530
tx.output.side_effects.mutation(self)
508531
self.items = prefix + list(self.items)
509-
return ConstantVariable.create(None)
532+
result = ConstantVariable.create(None)
510533
elif name == "popleft" and self.mutable_local:
511534
assert not args
512535
assert not kwargs
513536
item = self.items[0]
514537
tx.output.side_effects.mutation(self)
515538
self.items = self.items[1:]
516-
return item
539+
result = item
517540
elif name == "appendleft" and self.mutable_local:
518541
assert not kwargs
519542
tx.output.side_effects.mutation(self)
520543
self.items = [args[0]] + list(self.items)
521-
return ConstantVariable.create(None)
544+
result = ConstantVariable.create(None)
522545
else:
523-
return super().call_method(tx, name, args, kwargs)
546+
result = super().call_method(tx, name, args, kwargs)
547+
548+
if self.maxlen.as_python_constant() is not None:
549+
self.items = list(self.items)[-self.maxlen.as_python_constant() :]
550+
return result
524551

525552

526553
class TupleVariable(BaseListVariable):

torch/_dynamo/variables/user_defined.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -374,14 +374,31 @@ def call_function(
374374
if self.value.__optional_keys__:
375375
unimplemented("TypedDict with optional keys not supported")
376376
return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs)
377-
elif self.value is collections.deque and not kwargs:
378-
if len(args) == 0:
379-
items = []
380-
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
381-
items = args[0].force_unpack_var_sequence(tx)
377+
elif self.value is collections.deque:
378+
maxlen = variables.ConstantVariable.create(None)
379+
if not kwargs:
380+
if len(args) == 0:
381+
items = []
382+
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
383+
items = args[0].force_unpack_var_sequence(tx)
384+
elif len(args) == 2 and args[0].has_force_unpack_var_sequence(tx):
385+
items = args[0].force_unpack_var_sequence(tx)
386+
maxlen = args[1]
387+
else:
388+
unimplemented("deque() with more than 2 arg not supported")
389+
elif tuple(kwargs) == ("maxlen",):
390+
maxlen = kwargs["maxlen"]
391+
if len(args) == 0:
392+
items = []
393+
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
394+
items = args[0].force_unpack_var_sequence(tx)
395+
else:
396+
unimplemented("deque() with more than 1 arg not supported")
382397
else:
383-
unimplemented("deque() with more than 1 arg not supported")
384-
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
398+
unimplemented("deque() with invalid kwargs not supported")
399+
return variables.lists.DequeVariable(
400+
items, maxlen=maxlen, mutable_local=MutableLocal()
401+
)
385402
elif self.value is functools.partial:
386403
if not args:
387404
unimplemented("functools.partial malformed")

torch/utils/_pytree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ def tree_map_(
994994
"""
995995
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
996996
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
997-
tuple(map(func, *flat_args)) # consume and exhaust the iterable
997+
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
998998
return tree
999999

10001000

0 commit comments

Comments
 (0)