Skip to content

Commit 2457810

Browse files
authored
add internal async_dispatch util (#15813)
1 parent ab964c1 commit 2457810

File tree

2 files changed

+274
-0
lines changed

2 files changed

+274
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import asyncio
2+
import inspect
3+
from functools import wraps
4+
from typing import Any, Callable, Coroutine, Protocol, TypeVar, Union
5+
6+
from typing_extensions import ParamSpec
7+
8+
R = TypeVar("R")
9+
P = ParamSpec("P")
10+
11+
12+
class AsyncDispatchable(Protocol[P, R]):
13+
"""Protocol for functions decorated with async_dispatch."""
14+
15+
def __call__(
16+
self, *args: P.args, **kwargs: P.kwargs
17+
) -> Union[R, Coroutine[Any, Any, R]]:
18+
...
19+
20+
aio: Callable[P, Coroutine[Any, Any, R]]
21+
sync: Callable[P, R]
22+
23+
24+
def is_in_async_context() -> bool:
25+
"""Check if we're in an async context."""
26+
try:
27+
# First check if we're in a coroutine
28+
if asyncio.current_task() is not None:
29+
return True
30+
31+
# Check if we have a loop and it's running
32+
loop = asyncio.get_event_loop()
33+
return loop.is_running()
34+
except RuntimeError:
35+
return False
36+
37+
38+
def async_dispatch(
39+
async_impl: Callable[P, Coroutine[Any, Any, R]],
40+
) -> Callable[[Callable[P, R]], AsyncDispatchable[P, R]]:
41+
"""
42+
Decorator that adds async compatibility to a sync function.
43+
44+
The decorated function will:
45+
- Return a coroutine when in an async context (detected via running event loop)
46+
- Run synchronously when in a sync context
47+
- Provide .aio for explicit async access
48+
- Provide .sync for explicit sync access
49+
50+
Args:
51+
async_impl: The async implementation to dispatch to when async execution
52+
is needed
53+
"""
54+
if not inspect.iscoroutinefunction(async_impl):
55+
raise TypeError(
56+
"async_impl must be an async function to dispatch in async contexts"
57+
)
58+
59+
def decorator(sync_fn: Callable[P, R]) -> AsyncDispatchable[P, R]:
60+
@wraps(sync_fn)
61+
def wrapper(
62+
*args: P.args, **kwargs: P.kwargs
63+
) -> Union[R, Coroutine[Any, Any, R]]:
64+
if is_in_async_context():
65+
return async_impl(*args, **kwargs)
66+
return sync_fn(*args, **kwargs)
67+
68+
# Attach both async and sync implementations directly
69+
wrapper.aio = async_impl
70+
wrapper.sync = sync_fn
71+
return wrapper # type: ignore
72+
73+
return decorator
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
from prefect._internal.compatibility.async_dispatch import (
6+
async_dispatch,
7+
is_in_async_context,
8+
)
9+
from prefect.utilities.asyncutils import run_sync_in_worker_thread
10+
11+
12+
class TestAsyncDispatchBasicUsage:
13+
def test_async_compatible_fn_in_sync_context(self):
14+
data = []
15+
16+
async def my_function_async():
17+
data.append("async")
18+
19+
@async_dispatch(my_function_async)
20+
def my_function():
21+
data.append("sync")
22+
23+
my_function()
24+
assert data == ["sync"]
25+
26+
async def test_async_compatible_fn_in_async_context(self):
27+
data = []
28+
29+
async def my_function_async():
30+
data.append("async")
31+
32+
@async_dispatch(my_function_async)
33+
def my_function():
34+
data.append("sync")
35+
36+
await my_function()
37+
assert data == ["async"]
38+
39+
40+
class TestAsyncDispatchExplicitUsage:
41+
async def test_async_compatible_fn_explicit_async_usage(self):
42+
"""Verify .aio property works as expected"""
43+
data = []
44+
45+
async def my_function_async():
46+
data.append("async")
47+
48+
@async_dispatch(my_function_async)
49+
def my_function():
50+
data.append("sync")
51+
52+
await my_function.aio()
53+
assert data == ["async"]
54+
55+
def test_async_compatible_fn_explicit_async_usage_with_asyncio_run(self):
56+
"""Verify .aio property works as expected with asyncio.run"""
57+
data = []
58+
59+
async def my_function_async():
60+
data.append("async")
61+
62+
@async_dispatch(my_function_async)
63+
def my_function():
64+
data.append("sync")
65+
66+
asyncio.run(my_function.aio())
67+
assert data == ["async"]
68+
69+
async def test_async_compatible_fn_explicit_sync_usage(self):
70+
"""Verify .sync property works as expected in async context"""
71+
data = []
72+
73+
async def my_function_async():
74+
data.append("async")
75+
76+
@async_dispatch(my_function_async)
77+
def my_function():
78+
data.append("sync")
79+
80+
# Even though we're in async context, .sync should force sync execution
81+
my_function.sync()
82+
assert data == ["sync"]
83+
84+
def test_async_compatible_fn_explicit_sync_usage_in_sync_context(self):
85+
"""Verify .sync property works as expected in sync context"""
86+
data = []
87+
88+
async def my_function_async():
89+
data.append("async")
90+
91+
@async_dispatch(my_function_async)
92+
def my_function():
93+
data.append("sync")
94+
95+
my_function.sync()
96+
assert data == ["sync"]
97+
98+
99+
class TestAsyncDispatchValidation:
100+
def test_async_compatible_requires_async_implementation(self):
101+
"""Verify we properly reject non-async implementations"""
102+
103+
def not_async():
104+
pass
105+
106+
with pytest.raises(TypeError, match="async_impl must be an async function"):
107+
108+
@async_dispatch(not_async)
109+
def my_function():
110+
pass
111+
112+
async def test_async_compatible_fn_attributes_exist(self):
113+
"""Verify both .sync and .aio attributes are present"""
114+
115+
async def my_function_async():
116+
pass
117+
118+
@async_dispatch(my_function_async)
119+
def my_function():
120+
pass
121+
122+
assert hasattr(my_function, "sync"), "Should have .sync attribute"
123+
assert hasattr(my_function, "aio"), "Should have .aio attribute"
124+
assert (
125+
my_function.sync is my_function.__wrapped__
126+
), "Should reference original sync function"
127+
assert (
128+
my_function.aio is my_function_async
129+
), "Should reference original async function"
130+
131+
132+
class TestAsyncCompatibleFnCannotBeUsedWithAsyncioRun:
133+
def test_async_compatible_fn_in_sync_context_errors_with_asyncio_run(self):
134+
"""this is here to illustrate the expected behavior"""
135+
data = []
136+
137+
async def my_function_async():
138+
data.append("async")
139+
140+
@async_dispatch(my_function_async)
141+
def my_function():
142+
data.append("sync")
143+
144+
with pytest.raises(ValueError, match="coroutine was expected, got None"):
145+
asyncio.run(my_function())
146+
147+
async def test_async_compatible_fn_in_async_context_fails_with_asyncio_run(self):
148+
"""this is here to illustrate the expected behavior"""
149+
data = []
150+
151+
async def my_function_async():
152+
data.append("async")
153+
154+
@async_dispatch(my_function_async)
155+
def my_function():
156+
data.append("sync")
157+
158+
with pytest.raises(
159+
RuntimeError, match="cannot be called from a running event loop"
160+
):
161+
asyncio.run(my_function())
162+
163+
164+
class TestIsInAsyncContext:
165+
async def test_is_in_async_context_from_coroutine(self):
166+
"""Verify detection inside a coroutine"""
167+
assert is_in_async_context() is True
168+
169+
def test_is_in_async_context_from_sync(self):
170+
"""Verify detection in pure sync context"""
171+
assert is_in_async_context() is False
172+
173+
async def test_is_in_async_context_with_nested_sync_in_worker_thread(self):
174+
def sync_func():
175+
return is_in_async_context()
176+
177+
assert await run_sync_in_worker_thread(sync_func) is False
178+
179+
def test_is_in_async_context_with_running_loop(self):
180+
"""Verify detection with just a running event loop"""
181+
loop = asyncio.new_event_loop()
182+
asyncio.set_event_loop(loop)
183+
result = None
184+
185+
def check_context():
186+
nonlocal result
187+
result = is_in_async_context()
188+
loop.stop()
189+
190+
try:
191+
loop.call_soon(check_context)
192+
loop.run_forever()
193+
assert (
194+
result is True
195+
), "the result we captured while loop was running should be True"
196+
finally:
197+
loop.close()
198+
asyncio.set_event_loop(None)
199+
assert (
200+
is_in_async_context() is False
201+
), "the loop should be closed and not considered an async context"

0 commit comments

Comments
 (0)