diff --git a/lib-topaz/topaz.rb b/lib-topaz/topaz.rb index e915a862e..691e71050 100644 --- a/lib-topaz/topaz.rb +++ b/lib-topaz/topaz.rb @@ -1,4 +1,28 @@ module Topaz + def self.recursion_guard_outer(identifier, obj, &block) + # We want to throw something less likely to be caught accidentally outside + # our own code than the recursion identifier. Ideally this should be an + # object that is unique to this particular recursion guard. Since doing + # that properly requires pushing extra state all the way up into + # ExecutionContext, we do this instead. + throw_symbol = "__recursion_guard_#{identifier}".to_sym + + if Thread.current.in_recursion_guard?(identifier) + Thread.current.recursion_guard(identifier, obj) do + yield + return false + end + throw(throw_symbol) + else + Thread.current.recursion_guard(identifier, obj) do + catch(throw_symbol) do + yield + return false + end + return true + end + end + end end lib_topaz = File.join(File.dirname(__FILE__), 'topaz') diff --git a/tests/objects/test_threadobject.py b/tests/objects/test_threadobject.py index 64724c196..158b61736 100644 --- a/tests/objects/test_threadobject.py +++ b/tests/objects/test_threadobject.py @@ -15,44 +15,3 @@ def test_thread_local_storage(self, space): return Thread.current[:a] """) assert space.int_w(w_res) == 1 - - def test_recursion_guard(self, space): - w_res = space.execute(""" - def foo(objs, depth = 0) - obj = objs.shift - recursion = Thread.current.recursion_guard(:foo, obj) do - return foo(objs, depth + 1) - end - if recursion - return [depth, obj] - end - end - return foo([:a, :b, :c, :a, :d]) - """) - w_depth, w_symbol = space.listview(w_res) - assert space.int_w(w_depth) == 3 - assert space.symbol_w(w_symbol) == "a" - - def test_recursion_guard_nested(self, space): - w_res = space.execute(""" - def foo(objs, depth = 0) - obj = objs.shift - Thread.current.recursion_guard(:foo, obj) do - return bar(objs, depth + 1) - end - return [depth, obj] - end - - def bar(objs, depth) - obj = objs.shift - Thread.current.recursion_guard(:bar, obj) do - return foo(objs, depth + 1) - end - return [depth, obj] - end - - return foo([:a, :a, :b, :b, :c, :a, :d, :d]) - """) - w_depth, w_symbol = space.listview(w_res) - assert space.int_w(w_depth) == 5 - assert space.symbol_w(w_symbol) == "a" diff --git a/tests/test_recursion_guard.py b/tests/test_recursion_guard.py new file mode 100644 index 000000000..410072e6f --- /dev/null +++ b/tests/test_recursion_guard.py @@ -0,0 +1,56 @@ +class TestRecursionGuard(object): + def test_recursion_guard(self, space): + w_res = space.execute(""" + def foo(objs, depth = 0) + obj = objs.shift + recursion = Thread.current.recursion_guard(:foo, obj) do + return foo(objs, depth + 1) + end + if recursion + return [depth, obj] + end + end + return foo([:a, :b, :c, :a, :d]) + """) + w_depth, w_symbol = space.listview(w_res) + assert space.int_w(w_depth) == 3 + assert space.symbol_w(w_symbol) == "a" + + def test_recursion_guard_nested(self, space): + w_res = space.execute(""" + def foo(objs, depth = 0) + obj = objs.shift + Thread.current.recursion_guard(:foo, obj) do + return bar(objs, depth + 1) + end + return [depth, obj] + end + + def bar(objs, depth) + obj = objs.shift + Thread.current.recursion_guard(:bar, obj) do + return foo(objs, depth + 1) + end + return [depth, obj] + end + + return foo([:a, :a, :b, :b, :c, :a, :d, :d]) + """) + w_depth, w_symbol = space.listview(w_res) + assert space.int_w(w_depth) == 5 + assert space.symbol_w(w_symbol) == "a" + + def test_recursion_guard_outer(self, space): + w_res = space.execute(""" + def foo(objs, depth = 0) + obj = objs.shift + Topaz.recursion_guard_outer(:foo, obj) do + return foo(objs, depth + 1) + end + return [depth, obj] + end + return foo([:a, :b, :c, :a, :d]) + """) + w_depth, w_symbol = space.listview(w_res) + assert space.int_w(w_depth) == 0 + assert space.symbol_w(w_symbol) == "a" diff --git a/topaz/objects/threadobject.py b/topaz/objects/threadobject.py index ed51cd74e..6d9cafaa6 100644 --- a/topaz/objects/threadobject.py +++ b/topaz/objects/threadobject.py @@ -42,3 +42,11 @@ def method_recursion_guard(self, space, w_identifier, w_obj, block): if not in_recursion: space.invoke_block(block, []) return space.newbool(in_recursion) + + @classdef.method("in_recursion_guard?") + def method_in_recursion_guardp(self, space, w_identifier): + ec = space.getexecutioncontext() + identifier = space.symbol_w(w_identifier) + if identifier in ec.recursive_calls: + return space.w_true + return space.w_false