Skip to content

Commit f2ce2d4

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Add test for returning a nested recursive function and update documentation (pytorch#141713)
Addresses pytorch#137905 (comment). Pull Request resolved: pytorch#141713 Approved by: https://github.com/jansel
1 parent f8a64c3 commit f2ce2d4

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

test/dynamo/test_functions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3646,6 +3646,18 @@ def fn(a, b):
36463646
self.assertIsInstance(it2, enumerate)
36473647
self.assertEqual(list(it1), list(it2))
36483648

3649+
def test_returning_recursive_func(self):
3650+
@torch.compile(backend="eager", fullgraph=True)
3651+
def run(x):
3652+
def f():
3653+
return f
3654+
3655+
return x + 1, f
3656+
3657+
res, f = run(torch.zeros(1))
3658+
self.assertTrue(same(res, torch.ones(1)))
3659+
self.assertTrue(f is f())
3660+
36493661

36503662
def udf_mul(x, y):
36513663
return x * y

torch/_dynamo/side_effects.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,10 @@ def track_tensor_variables_from_runahead_side_effects(self, other):
348348
self.track_object_existing(other_item, other_variable)
349349

350350
def prune_dead_object_new(self, tx):
351+
# Avoid VT cycles from e.g., recursive function.
352+
visited: Set[VariableTracker] = set()
351353
live_new_objects: Set[VariableTracker] = set()
352354

353-
# use this to avoid cycles in mutation_type (though I'm not sure if that
354-
# can actually happen).
355-
visited: Set[VariableTracker] = set({})
356-
357355
def visit(var: VariableTracker):
358356
if var in visited:
359357
return

0 commit comments

Comments
 (0)