Skip to content

Ensure async fixture setup and teardown run in the same task #1193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections.abc import (
AsyncIterator,
Awaitable,
Coroutine as CoroutineT,
Generator,
Iterable,
Iterator,
Expand Down Expand Up @@ -276,6 +277,69 @@ def _fixture_synchronizer(
AsyncGenFixtureYieldType = TypeVar("AsyncGenFixtureYieldType")


def _create_task_in_context(
coro: CoroutineT[Any, Any, Any],
loop: AbstractEventLoop,
context: contextvars.Context,
) -> asyncio.Task[Any]:
if sys.version_info >= (3, 11):
return loop.create_task(coro, context=context)

from backports.asyncio.runner._patch import _patch_object
from backports.asyncio.runner.tasks import Task

with (
_patch_object(asyncio.tasks, asyncio.tasks.Task.__name__, Task),
_patch_object(contextvars, contextvars.copy_context.__name__, lambda: context),
):
return loop.create_task(coro)


class _FixtureRunner:
def __init__(self, loop: AbstractEventLoop, context: contextvars.Context) -> None:
self.loop = loop
self.queue: asyncio.Queue[tuple[Awaitable[Any], asyncio.Future[Any]] | None] = (
asyncio.Queue()
)
self._context = context
self._task = None

async def _worker(self) -> None:
while True:
item = await self.queue.get()
if item is None:
break
coro, future = item
try:
retval = await coro
future.set_result(retval)
except Exception as exc:
future.set_exception(exc)

def run(self, func):
return self.loop.run_until_complete(self._run(func))

async def _run(self, func):
if self._task is None:
self._task = _create_task_in_context(
self._worker(), loop=self.loop, context=self._context
)

coro = func()
future = self.loop.create_future()
self.queue.put_nowait((coro, future))
return await future

async def _stop(self):
self.queue.put_nowait(None)
if self._task is not None:
await self._task
self._task = None

def stop(self) -> None:
self.loop.run_until_complete(self._stop())


def _wrap_asyncgen_fixture(
fixture_function: Callable[
AsyncGenFixtureParams, AsyncGeneratorType[AsyncGenFixtureYieldType, Any]
Expand All @@ -295,7 +359,8 @@ async def setup():
return res

context = contextvars.copy_context()
result = runner.run(setup(), context=context)
fixture_runner = _FixtureRunner(loop=runner.get_loop(), context=context)
result = fixture_runner.run(setup)

reset_contextvars = _apply_contextvar_changes(context)

Expand All @@ -312,7 +377,8 @@ async def async_finalizer() -> None:
msg += "Yield only once."
raise ValueError(msg)

runner.run(async_finalizer(), context=context)
fixture_runner.run(async_finalizer)
fixture_runner.stop()
if reset_contextvars is not None:
reset_contextvars()

Expand Down
11 changes: 11 additions & 0 deletions tests/async_fixtures/test_async_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,14 @@ async def async_fixture_method(self):
@pytest.mark.asyncio
async def test_async_fixture_method(self):
assert self.is_same_instance


@pytest.fixture()
async def setup_and_teardown_tasks():
task = asyncio.current_task()
yield
assert task is asyncio.current_task()


async def test_setup_and_teardown_tasks(setup_and_teardown_tasks):
pass
Loading