Skip to content

Commit 9eb0520

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Fix side-effect handling for pre-existing collections.deque (pytorch#141714)
Previously we never replayed side effects to `DequeVariable` with a source; the bug was already in the `test_deque_input` test, but went unnoticed because we didn't check the deque objects. This patch adds limited but practical support for this (see comments in `side_effects.py` for why limited), and updates the deque tests to check for this. Pull Request resolved: pytorch#141714 Approved by: https://github.com/jansel ghstack dependencies: pytorch#141713
1 parent f2ce2d4 commit 9eb0520

File tree

3 files changed

+40
-16
lines changed

3 files changed

+40
-16
lines changed

test/dynamo/test_misc.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9387,11 +9387,8 @@ def start():
93879387
def test_deque_input(self):
93889388
a = torch.randn([2, 3])
93899389
b = torch.randn([2, 3])
9390-
d1 = collections.deque([a, b])
9391-
d1.insert(0, "foo")
9392-
9393-
d2 = collections.deque([a, b])
9394-
d2.insert(0, "foo")
9390+
d1 = collections.deque(["foo", a, b])
9391+
d2 = d1.copy()
93959392

93969393
def fn(q):
93979394
a = q.pop()
@@ -9400,16 +9397,14 @@ def fn(q):
94009397

94019398
eager = fn(d1)
94029399
counter = CompileCounter()
9403-
compiled = torch.compile(fn, backend=counter)(d2)
9400+
compiled = torch.compile(fn, backend=counter, fullgraph=True)(d2)
9401+
self.assertEqual(d1, d2)
94049402
self.assertEqual(eager, compiled)
94059403
self.assertEqual(counter.frame_count, 1)
94069404

94079405
def test_deque_append_left(self):
9408-
d1 = collections.deque([10, 10])
9409-
d1.insert(0, "foo")
9410-
9411-
d2 = collections.deque([10, 10])
9412-
d2.insert(0, "foo")
9406+
d1 = collections.deque(["foo", 10, 10])
9407+
d2 = d1.copy()
94139408

94149409
def fn(q, a, b):
94159410
q.appendleft(a)
@@ -9420,7 +9415,8 @@ def fn(q, a, b):
94209415
b = torch.randn([3, 3])
94219416
eager = fn(d1, a, b)
94229417
counter = CompileCounter()
9423-
compiled = torch.compile(fn, backend=counter)(d2, a, b)
9418+
compiled = torch.compile(fn, backend=counter, fullgraph=True)(d2, a, b)
9419+
self.assertEqual(d1, d2)
94249420
self.assertEqual(eager, compiled)
94259421
self.assertEqual(counter.frame_count, 1)
94269422
self.assertTrue(isinstance(compiled, torch.Tensor))

torch/_dynamo/side_effects.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,36 @@ def codegen_update_mutated(self, cg: PyCodegen):
562562
]
563563
)
564564
suffixes.append([create_instruction("STORE_SUBSCR")])
565+
elif isinstance(var, variables.lists.DequeVariable):
566+
# For limited maxlen, the order of operations matter for side
567+
# effect, but we currently don't track the order, so no support.
568+
if not (
569+
isinstance(var.maxlen, variables.ConstantVariable)
570+
and var.maxlen.value is None
571+
):
572+
unimplemented("side effect on existing deque with limited maxlen")
573+
574+
# old.extend(new), this runs last
575+
cg(var.source)
576+
cg.load_method("extend")
577+
cg(var, allow_cache=False) # Don't codegen via source
578+
suffixes.append(
579+
[
580+
*create_call_method(1),
581+
create_instruction("POP_TOP"),
582+
]
583+
)
584+
585+
# old.clear(), this runs first
586+
cg(var.source)
587+
cg.load_method("clear")
588+
suffixes.append(
589+
[
590+
*create_call_method(0),
591+
create_instruction("POP_TOP"),
592+
]
593+
)
594+
565595
elif isinstance(var, variables.CustomizedDictVariable):
566596
# need to update the dict manually since update method may be invalid
567597
varname_map = {}

torch/_dynamo/variables/builder.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,10 +1326,8 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
13261326
)
13271327
tensor_list_proxy.node.meta["grapharg"] = grapharg
13281328

1329-
result = BaseListVariable.cls_for_instance(value)(
1330-
output, mutation_type=ValueMutationNew()
1331-
)
1332-
if istype(value, list):
1329+
result = BaseListVariable.cls_for_instance(value)(output)
1330+
if istype(value, (list, collections.deque)):
13331331
return self.set_source_and_track_mutable(value, result)
13341332
return result
13351333

0 commit comments

Comments
 (0)