diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 3a45a80..ca80840 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -2,17 +2,18 @@ name: Testing on: pull_request: - branches: [develop] + branches: [develop, main] jobs: - run-tests: - runs-on: ubuntu-latest - + run-unit-tests: strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + os: [ubuntu-latest, windows-latest, macos-latest] fail-fast: false + runs-on: ${{ matrix.os }} + steps: - name: Checkout uses: actions/checkout@v5 @@ -25,11 +26,40 @@ jobs: - name: Setup uv uses: astral-sh/setup-uv@v7 - - name: Create virtual environment - run: uv venv .venv && source .venv/bin/activate + # - name: Create virtual environment + # run: uv venv .venv && source .venv/bin/activate - name: Install modules run: uv sync - - name: Run tests for Python ${{ matrix.python-version }} - run: uv run pytest + - name: Run unit tests for Python ${{ matrix.python-version }} + run: uv run pytest tests/unit + + run-integration-tests: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + fail-fast: false + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Setup uv + uses: astral-sh/setup-uv@v7 + + - name: Install modules + run: uv sync --extra redis --extra aiopika + + - name: Setup Docker containers + run: docker compose -f ./docker-compose-tests.yml up --wait + + - name: Run integration tests for Python ${{ matrix.python-version }} + run: uv run pytest tests/integration diff --git a/docker-compose-tests.yml b/docker-compose-tests.yml new file mode 100644 index 0000000..ffa0fbd --- /dev/null +++ b/docker-compose-tests.yml @@ -0,0 +1,20 @@ +services: + redis: + image: redis:latest + ports: + - "6379:6379" + + rabbitmq: + image: rabbitmq:latest + environment: + - RABBITMQ_DEFAULT_USER=guest + - RABBITMQ_DEFAULT_PASSWORD=guest + hostname: localhost + ports: + - "5672:5672" + healthcheck: + test: "rabbitmq-diagnostics check_running -q" + interval: 5s + timeout: 5s + retries: 10 + start_period: 5s \ No newline at end of file diff --git a/examples/counter/main.py b/examples/counter/main.py index a148aae..61f9789 100644 --- a/examples/counter/main.py +++ b/examples/counter/main.py @@ -1,7 +1,7 @@ import asyncio from taskiq_redis import PubSubBroker, RedisAsyncResultBackend -from taskiq_cancellation.integrations.redis import RedisCancellationBackend +from taskiq_cancellation.backends.redis import RedisCancellationBackend url = "redis://localhost" diff --git a/pyproject.toml b/pyproject.toml index 869d263..1df4f56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,14 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["taskiq", "typing-extensions>=4.13.2"] +dependencies = [ + "async-timeout>=5.0.1 ; python_full_version < '3.11'", + "taskiq", + "typing-extensions>=4.15.0 ; python_full_version < '3.11'", +] [project.optional-dependencies] -redis = ["redis~=3.0"] +redis = ["redis~=6.0"] aiopika = ["aio_pika"] [project.urls] @@ -44,8 +48,12 @@ check = "mypy --install-types --non-interactive {args:src/taskiq_cancellation te [tool.mypy] ignore_missing_imports = true -exclude = ["examples"] - +exclude = [ + # Added so that "mypy ." would work + "examples", + # Contains Python 3.11+ code, has to be excluded. Runtime checks don't work. + "src/taskiq_cancellation/cancellation_handlers/edge_3_11.py", +] [tool.coverage.run] source_pkgs = ["taskiq_cancellation", "tests"] @@ -70,3 +78,8 @@ dev = [ "pytest-asyncio>=0.24.0", "ruff>=0.14.4", ] + +[tool.ruff] +# Edge cancellation is currently Python 3.11+ which causes linter to freak out +# For such versioning cases tests must be written +target-version = "py314" diff --git a/src/taskiq_cancellation/__init__.py b/src/taskiq_cancellation/__init__.py index 57864e4..db24b56 100644 --- a/src/taskiq_cancellation/__init__.py +++ b/src/taskiq_cancellation/__init__.py @@ -2,7 +2,4 @@ from .backends.modular import ModularCancellationBackend -__all__ = [ - "CancellationBackend", - "ModularCancellationBackend" -] +__all__ = ["CancellationBackend", "ModularCancellationBackend"] diff --git a/src/taskiq_cancellation/abc/__init__.py b/src/taskiq_cancellation/abc/__init__.py index 1decfa6..45239e2 100644 --- a/src/taskiq_cancellation/abc/__init__.py +++ b/src/taskiq_cancellation/abc/__init__.py @@ -1,6 +1,12 @@ from .backend import CancellationBackend from .notifier import CancellationNotifier from .state_holder import CancellationStateHolder +from .started_listening_event import StartedListeningEvent -__all__ = ["CancellationBackend", "CancellationNotifier", "CancellationStateHolder"] +__all__ = [ + "CancellationBackend", + "CancellationNotifier", + "CancellationStateHolder", + "StartedListeningEvent", +] diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index c3befb9..6f20b5e 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -1,34 +1,44 @@ import abc -import asyncio -from typing import Callable, Annotated, TypeVar, Awaitable -from typing_extensions import ParamSpec, Self +import sys +import inspect +from typing import Callable, Annotated, overload, Optional, cast, Union + +if sys.version_info >= (3, 11): + from typing import Self, ParamSpec, TypeVar +else: + from typing_extensions import Self, ParamSpec, TypeVar -import anyio -from anyio.abc import TaskStatus from taskiq import Context, TaskiqDepends, AsyncBroker, TaskiqEvents, TaskiqState from taskiq_cancellation.utils import combines -from taskiq_cancellation.exceptions import TaskCancellationException +from taskiq_cancellation.cancellation_handlers import ( + CancellationType, + LevelCancellationHandler, + EdgeCancellationHandler, +) + +from .started_listening_event import StartedListeningEvent -P = ParamSpec("P") -R = TypeVar("R") +Params = ParamSpec("Params") +Result = TypeVar("Result") class CancellationBackend(abc.ABC): """ Base class for cancellation backend """ + def __init__(self) -> None: super().__init__() - self.broker: AsyncBroker | None = None + self.broker: Union[AsyncBroker, None] = None @abc.abstractmethod async def is_cancelled(self, task_id: str) -> bool: """ Checks if a task with task id of *task_id* is set to be cancelled - + :param task_id: task id to check :type task_id: str :returns: task cancellation state @@ -48,19 +58,19 @@ async def cancel(self, task_id: str) -> None: @abc.abstractmethod async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: """ Listens for cancellation messages and raises :ref:`TaskCancellationException` when receives :ref:`CancellationMessage` with same id as *task_id*. - This function is used in :func:`cancellable` decorator. - Call `started_listening_task_status.started()` when the listener is ready + This function is used in :func:`cancellable` decorator. + Call `started_listening_task_status.started()` when the listener is ready to receive messages. - :param task_id: id of task that will be listened for + :param task_id: id of task that will be listened for :type task_id: str - :param started_listening_task_status: + :param started_listening_task_status: :type started_listening_task_status: anyio.abc.TaskStatus """ pass @@ -75,7 +85,7 @@ async def startup(self) -> None: async def shutdown(self) -> None: """Shuts down cancellation backend - + Triggered only if backend has a broker set. To set a broker use :ref:`with_broker`. """ pass @@ -84,7 +94,7 @@ def with_broker(self, broker: AsyncBroker) -> Self: """ Set a broker and return updated cancellation backend - Sets up startup and shutdown event handlers for backend's startup + Sets up startup and shutdown event handlers for backend's startup and shutdown methods respectfully :param broker: new broker @@ -121,7 +131,19 @@ def with_broker(self, broker: AsyncBroker) -> Self: return self - def cancellable(self, task: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + @overload + def cancellable( + self, cancellation_type: Callable[Params, Result] + ) -> Callable[Params, Result]: + pass + + @overload + def cancellable( + self, cancellation_type: Optional[CancellationType] = None + ) -> Callable[[Callable[Params, Result]], Callable[Params, Result]]: + pass + + def cancellable(self, cancellation_type=None): """ Decorator that makes funcion cancellable @@ -133,73 +155,56 @@ def cancellable(self, task: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[ - Raises :ref:`TaskCancellationException` if listener task receives cancellation message - If listener task raises an exception, task is cancelled and exception is propogated upwards - :param task: Task function to wrap + :param cancellation_type: type of cancellation used + :type cancellation_type: CancellationType :returns: Cancellable task function """ - # Executor type depends on receiver configuration which we can't accessed in any way - if not asyncio.iscoroutinefunction(task): - raise ValueError("Can't cancel synchronous function") - - @combines(task) - async def wrapper( - *args, __taskiq_context: Annotated[Context, TaskiqDepends()], **kwargs - ): - task_id = __taskiq_context.message.task_id - result = None - - listener_exception: Exception | None = None - task_exception: Exception | None = None - cancelled_by_request: bool = False - - async with anyio.create_task_group() as group: - - async def listen_for_cancellation( - task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, - ): - nonlocal listener_exception, cancelled_by_request - - try: - await self.listen_for_cancellation(task_id, task_status) - except TaskCancellationException: - cancelled_by_request = True - except anyio.get_cancelled_exc_class(): - pass - except Exception as e: - listener_exception = e - finally: - group.cancel_scope.cancel() - - async def call_task(): - nonlocal result, task_exception - - try: - result = await task(*args, **kwargs) - except anyio.get_cancelled_exc_class(): - pass - except Exception as e: - task_exception = e - finally: - group.cancel_scope.cancel() - - # Listen before checking for cancellation in state holder - # so the message won't get lost in non-persistent queues - await group.start(listen_for_cancellation) - if await self.is_cancelled(task_id): - cancelled_by_request = True - group.cancel_scope.cancel() - else: - group.start_soon(call_task) - - if task_exception is not None: - raise task_exception - elif cancelled_by_request: - raise TaskCancellationException() - elif listener_exception is not None: - raise listener_exception - else: - return result - - return wrapper + + defaults = {"cancellation_type": CancellationType.LEVEL} + + def make_decorator(cancellation_type: CancellationType): + def decorator( + task: Callable[Params, Result], / + ) -> Callable[Params, Result]: + # Executor type depends on receiver configuration which we can't accessed in any way + if not inspect.iscoroutinefunction(task): + raise ValueError("Can't cancel synchronous function") + + @combines(task) + async def wrapper( + *args, + __taskiq_context: Annotated[Context, TaskiqDepends(Context)] = None, # type: ignore + **kwargs, + ) -> Result: + if __taskiq_context is None: + # Ran the function directly, without kiq + return await task(*args, **kwargs) + + task_id = __taskiq_context.message.task_id + + if cancellation_type is CancellationType.EDGE: + edge_handler = EdgeCancellationHandler(self, task, task_id) + return await edge_handler(*args, **kwargs) + elif cancellation_type is CancellationType.LEVEL: + level_handler = LevelCancellationHandler(self, task, task_id) + return await level_handler(*args, **kwargs) + else: + raise ValueError( + f"Unknown cancellation type: {cancellation_type!r}" + ) + + # Wrapper adds a key-word only param with default value + casted_wrapper = cast(Callable[Params, Result], wrapper) + return casted_wrapper + + return decorator + + if callable(cancellation_type): + return make_decorator(**defaults)(cancellation_type) + else: + return make_decorator( + cancellation_type=cancellation_type or defaults["cancellation_type"] + ) async def _broker_startup_handler(self, _: TaskiqState) -> None: await self.startup() diff --git a/src/taskiq_cancellation/abc/notifier.py b/src/taskiq_cancellation/abc/notifier.py index 3ba408f..60b1d14 100644 --- a/src/taskiq_cancellation/abc/notifier.py +++ b/src/taskiq_cancellation/abc/notifier.py @@ -1,15 +1,22 @@ +import sys import abc -from anyio.abc import TaskStatus +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + from taskiq.abc.serializer import TaskiqSerializer from taskiq.serializers import JSONSerializer +from .started_listening_event import StartedListeningEvent + class CancellationNotifier(abc.ABC): """Receives cancellation messages and notifies listeners of these messages""" - def __init__(self, serializer: TaskiqSerializer = JSONSerializer()): - self.serializer = serializer + def __init__(self): + self.serializer = JSONSerializer() async def startup(self) -> None: """Starts up cancellation notifier""" @@ -31,19 +38,30 @@ async def cancel(self, task_id: str) -> None: @abc.abstractmethod async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: """ Listens for cancellation messages and raises :ref:`TaskCancellationException` when receives :ref:`CancellationMessage` with same id as *task_id*. - This function is used in :func:`cancellable` decorator of :ref:`ModularCancellationBackend`. - Call `started_listening_task_status.started()` when the listener is ready - to receive messages. + This function is used in :func:`cancellable` decorator of :ref:`ModularCancellationBackend`. + Call `started_listening_event.set()` when the listener is ready to receive messages. - :param task_id: id of task that will be listened for + :param task_id: id of task that will be listened for :type task_id: str - :param started_listening_task_status: - :type started_listening_task_status: anyio.abc.TaskStatus + :param started_listening_event: "listener started listening" confirmation event + :type started_listening_event: StartedListeningEvent """ pass + + def with_serializer(self, serializer: TaskiqSerializer) -> Self: + """ + Sets a serializer to be used by the notifier + + :param serializer: serializer for cancellation messages + :type serializer: TaskiqSerializer + :return: self + """ + self.serializer = serializer + + return self diff --git a/src/taskiq_cancellation/abc/started_listening_event.py b/src/taskiq_cancellation/abc/started_listening_event.py new file mode 100644 index 0000000..4e43acf --- /dev/null +++ b/src/taskiq_cancellation/abc/started_listening_event.py @@ -0,0 +1,22 @@ +import abc + + +class StartedListeningEvent(abc.ABC): + """ + A confirmation event for listeners to mark that they started listening to messages. API is + similar to :ref:`asyncio.Event`. + + This is needed for different cancellation types: + - Level cancellation uses :ref:`anyio.abc.TaskStatus` + - Edge cancellation uses :ref:`asyncio.Event` + """ + + @abc.abstractmethod + async def set(self): + """Sets the event""" + pass + + @abc.abstractmethod + async def wait(self): + """Waits for the event to be set""" + pass diff --git a/src/taskiq_cancellation/abc/state_holder.py b/src/taskiq_cancellation/abc/state_holder.py index 8416440..bb2a707 100644 --- a/src/taskiq_cancellation/abc/state_holder.py +++ b/src/taskiq_cancellation/abc/state_holder.py @@ -18,7 +18,7 @@ async def cancel(self, task_id: str) -> None: async def is_cancelled(self, task_id: str) -> bool: """ Checks if a task with task id of *task_id* is set to be cancelled - + :param task_id: task id to check :type task_id: str :returns: task cancellation state diff --git a/src/taskiq_cancellation/backends/in_memory.py b/src/taskiq_cancellation/backends/in_memory.py index ee5e385..795ac6e 100644 --- a/src/taskiq_cancellation/backends/in_memory.py +++ b/src/taskiq_cancellation/backends/in_memory.py @@ -5,8 +5,14 @@ class InMemoryCancellationBackend(ModularCancellationBackend): + """ + Cancellation backend that stores state and notifications in memory + + Useful for testing purposes + """ + def __init__(self, **kwargs): super().__init__( state_holder=InMemoryCancellationStateHolder(**kwargs), - notifier=InMemoryCancellationNotifier(**kwargs) + notifier=InMemoryCancellationNotifier(**kwargs), ) diff --git a/src/taskiq_cancellation/backends/modular.py b/src/taskiq_cancellation/backends/modular.py index 2d87dfe..c65c00c 100644 --- a/src/taskiq_cancellation/backends/modular.py +++ b/src/taskiq_cancellation/backends/modular.py @@ -1,17 +1,31 @@ -from taskiq_cancellation.abc import CancellationBackend, CancellationNotifier, CancellationStateHolder +import sys + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from taskiq.abc.serializer import TaskiqSerializer + +from taskiq_cancellation.abc import ( + CancellationBackend, + CancellationNotifier, + CancellationStateHolder, + StartedListeningEvent, +) import anyio -from anyio.abc import TaskStatus class ModularCancellationBackend(CancellationBackend): """ - Modular cancellation backend made up of :class:`CancellationStateHolder` + Modular cancellation backend made up of :class:`CancellationStateHolder` and :class:`CancellationNotifier` - - `CancellationStateHolder` stores cancellation state and blocks the task from being run. + - `CancellationStateHolder` stores cancellation state and blocks the task from being run. - `CancellationNotifier` receives cancellation messages and cancels already running tasks. """ + def __init__( self, state_holder: CancellationStateHolder, notifier: CancellationNotifier ): @@ -29,18 +43,28 @@ async def cancel(self, task_id: str): group.start_soon(self.notifier.cancel, task_id) async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus[None] + self, task_id: str, started_listening_event: StartedListeningEvent ): - await self.notifier.listen_for_cancellation( - task_id, started_listening_task_status - ) + await self.notifier.listen_for_cancellation(task_id, started_listening_event) async def startup(self) -> None: await super().startup() await self.state_holder.startup() await self.notifier.startup() - + async def shutdown(self) -> None: await super().shutdown() await self.state_holder.shutdown() await self.notifier.shutdown() + + def with_serializer(self, serializer: TaskiqSerializer) -> Self: + """ + Sets a serializer to be used by the notifier + + :param serializer: serializer for cancellation messages + :type serializer: TaskiqSerializer + :return: self + """ + self.notifier = self.notifier.with_serializer(serializer) + + return self diff --git a/src/taskiq_cancellation/cancellation_handlers/__init__.py b/src/taskiq_cancellation/cancellation_handlers/__init__.py new file mode 100644 index 0000000..a228c8d --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/__init__.py @@ -0,0 +1,12 @@ +import sys + +from .cancellation_type import CancellationType +from .level import LevelCancellationHandler + +if sys.version_info >= (3, 11): + from .edge_3_11 import EdgeCancellationHandler +else: + from .edge_non_supported import EdgeCancellationHandler + + +__all__ = ["CancellationType", "LevelCancellationHandler", "EdgeCancellationHandler"] diff --git a/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py new file mode 100644 index 0000000..9714e67 --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py @@ -0,0 +1,8 @@ +import enum + + +class CancellationType(str, enum.Enum): + """Type of cancellation used by the backend""" + + EDGE = "edge" + LEVEL = "level" diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py new file mode 100644 index 0000000..4bf66b5 --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py @@ -0,0 +1,136 @@ +# FIXME: bunch of issues because of Python 3.11+ exclusivity +# +# Edge cancellation handler is using asyncio.TaskGroup which was introduced in 3.11 and +# uses expect-star syntax that was also introduced in the same version +# - mypy has to ignore this file because it can't finish static parsing +# - ruff's python version has to be set at 3.11+ so it wouldn't complain +# +# I'm not sure how to mitigate these issues. Maybe this can be put in a separate module somehow +# and then integrated? Maybe this can be rewritten to not use TaskGroup (probably easier to do)? + +import sys +import logging +import asyncio +from collections.abc import Coroutine +from typing import Callable, TYPE_CHECKING, Generic, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar + +from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent +from taskiq_cancellation.exceptions import TaskCancellationException +from taskiq_cancellation.utils import StopTaskGroupException + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class EdgeCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses edge cancellation provided by asyncio. That means :ref:`asyncio.CancelledError` is + raised only once for the task. + Docs: https://docs.python.org/3/library/asyncio-task.html#task-cancellation + + Currently is supported in Python 3.11+ due to using :ref:`asyncio.TaskGroup`. + """ + + class ListeningEvent(StartedListeningEvent): + def __init__(self) -> None: + self.event = asyncio.Event() + + async def set(self): + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(self.event.set) + + async def wait(self): + await self.event.wait() + + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ): + self.backend = backend + self.task = task + self.task_id = task_id + + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + result: Result = None # type: ignore + + listener_exception: Exception | None = None + task_exception: Exception | None = None + cancelled_by_request: bool = False + + async def listen_for_cancellation(event: StartedListeningEvent): + nonlocal listener_exception, cancelled_by_request + + try: + await self.backend.listen_for_cancellation(self.task_id, event) + except TaskCancellationException: + cancelled_by_request = True + raise + except asyncio.CancelledError: + raise + except Exception as e: + listener_exception = e + raise + + async def call_task(): + nonlocal result, task_exception + + try: + result = await self.task(*args, **kwargs) + except asyncio.CancelledError: + raise + except Exception as e: + task_exception = e + raise + + try: + async with asyncio.TaskGroup() as tg: + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues + event = self.ListeningEvent() + tg.create_task(listen_for_cancellation(event)) + await event.wait() + + if await self.backend.is_cancelled(self.task_id): + cancelled_by_request = True + raise StopTaskGroupException() + + task_task = asyncio.create_task(call_task()) + await task_task + if not task_task.cancelled(): + raise StopTaskGroupException() + except* StopTaskGroupException: + pass + except* Exception as exc_group: + uncaught_exceptions = list( + filter( + lambda e: e == task_exception or e == listener_exception, + exc_group.exceptions, + ) + ) + + if uncaught_exceptions: + logging.log(logging.ERROR, "Uncaught exceptions in TaskGroup") + for e in uncaught_exceptions: + logging.exception(e) + + if task_exception is not None: + raise task_exception + elif cancelled_by_request: + raise TaskCancellationException() + elif listener_exception is not None: + raise listener_exception + else: + return result diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py new file mode 100644 index 0000000..03d06fd --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py @@ -0,0 +1,34 @@ +import sys +from typing import Callable, TYPE_CHECKING, Coroutine, Generic, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class EdgeCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses edge cancellation provided by asyncio. Currently is supported in Python 3.11+ due + to using :ref:`asyncio.TaskGroup`. + """ + + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ) -> None: + raise NotImplementedError("Edge cancellation is not supported for Python <3.11") + + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + raise NotImplementedError("Edge cancellation is not supported for Python <3.11") diff --git a/src/taskiq_cancellation/cancellation_handlers/level.py b/src/taskiq_cancellation/cancellation_handlers/level.py new file mode 100644 index 0000000..2cfe71a --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/level.py @@ -0,0 +1,110 @@ +import sys +from collections.abc import Coroutine +from typing import Callable, TYPE_CHECKING, Generic, Union, cast, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar + +import anyio +from anyio.abc import TaskStatus + +from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent +from taskiq_cancellation.exceptions import TaskCancellationException + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class LevelCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses level cancellation provided by anyio. That means cancellation exception is raised + on every await in the coroutine. + Docs: https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics + """ + + class ListeningEvent(StartedListeningEvent): + def __init__(self, task_status: TaskStatus) -> None: + self.task_status = task_status + + async def set(self): + self.task_status.started() + + async def wait(self): + # Can ignore, won't execute further before task status is set + pass + + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ): + self.backend = backend + self.task = task + self.task_id = task_id + + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + result: Union[Result, None] = None + + listener_exception: Union[Exception, None] = None + task_exception: Union[Exception, None] = None + cancelled_by_request: bool = False + + async with anyio.create_task_group() as group: + + async def listen_for_cancellation( + task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ): + nonlocal listener_exception, cancelled_by_request + + event = self.ListeningEvent(task_status) + try: + await self.backend.listen_for_cancellation(self.task_id, event) + except TaskCancellationException: + cancelled_by_request = True + except anyio.get_cancelled_exc_class(): + pass + except Exception as e: + listener_exception = e + finally: + group.cancel_scope.cancel() + + async def call_task(): + nonlocal result, task_exception + + try: + result = await self.task(*args, **kwargs) + except anyio.get_cancelled_exc_class(): + pass + except Exception as e: + task_exception = e + finally: + group.cancel_scope.cancel() + + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues + await group.start(listen_for_cancellation) + if await self.backend.is_cancelled(self.task_id): + cancelled_by_request = True + group.cancel_scope.cancel() + else: + group.start_soon(call_task) + + if task_exception is not None: + raise task_exception + elif cancelled_by_request: + raise TaskCancellationException() + elif listener_exception is not None: + raise listener_exception + else: + # If the task is finished, it is definitely not None + result = cast(Result, result) + return result diff --git a/src/taskiq_cancellation/notifiers/aiopika.py b/src/taskiq_cancellation/notifiers/aiopika.py index 8e55ab8..2432e56 100644 --- a/src/taskiq_cancellation/notifiers/aiopika.py +++ b/src/taskiq_cancellation/notifiers/aiopika.py @@ -10,52 +10,67 @@ class AioPikaNotifier(QueueCancellationNotifier): + """Notifier for RabbitMQ using aio-pika""" + EXCHANGE_NAME = "__taskiq_cancellation" - def __init__(self, url: str, **kwargs): - super().__init__(**kwargs) + def __init__(self, url: str, **connection_kwargs): + """ + Creates AioPika notifier + + :param url: url to rabbitmq + :type url: str + :param connection_kwargs: arguments for :ref:`aio_pika.connect_robust` + """ + super().__init__() self.url: str = url + self.connection_kwargs = connection_kwargs async def cancel(self, task_id: str) -> None: timestamp = time.time() - connection = await aio_pika.connect_robust(self.url) - channel = await connection.channel() - - exchange = await channel.declare_exchange( - self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True - ) - - await exchange.publish( - aio_pika.Message( - body=self.serializer.dumpb( - model_dump( - CancellationMessage(task_id=task_id, timestamp=timestamp) + + connection = await aio_pika.connect_robust(self.url, **self.connection_kwargs) + + async with connection: + channel = await connection.channel() + + exchange = await channel.declare_exchange( + self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True + ) + + await exchange.publish( + aio_pika.Message( + body=self.serializer.dumpb( + model_dump( + CancellationMessage(task_id=task_id, timestamp=timestamp) + ) ) - ) - ), - routing_key="", - ) + ), + routing_key="", + ) async def _listen(self, started_listening: asyncio.Event): - connection = await aio_pika.connect_robust(self.url) - channel = await connection.channel() - - exchange = await channel.declare_exchange( - self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True - ) - queue = await channel.declare_queue(exclusive=True, auto_delete=True) - await queue.bind(exchange) - - loop = asyncio.get_running_loop() - loop.call_soon_threadsafe(started_listening.set) - - async with queue.iterator() as queue_iter: - async for message in queue_iter: - cancellation_message = model_validate( - CancellationMessage, self.serializer.loadb(message.body) - ) - - for queue in self.queues: - await queue.put(cancellation_message) - await message.ack() + connection = await aio_pika.connect_robust(self.url, **self.connection_kwargs) + + async with connection: + channel = await connection.channel() + + exchange = await channel.declare_exchange( + self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True + ) + queue = await channel.declare_queue(exclusive=True, auto_delete=True) + await queue.bind(exchange) + + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(started_listening.set) + + async with queue.iterator() as queue_iter: + async for message in queue_iter: + cancellation_message = model_validate( + CancellationMessage, self.serializer.loadb(message.body) + ) + + for subscriber_queue in self.queues: + await subscriber_queue.put(cancellation_message) + await message.ack() diff --git a/src/taskiq_cancellation/notifiers/in_memory.py b/src/taskiq_cancellation/notifiers/in_memory.py index 748a129..1dfe59f 100644 --- a/src/taskiq_cancellation/notifiers/in_memory.py +++ b/src/taskiq_cancellation/notifiers/in_memory.py @@ -1,5 +1,6 @@ import time import asyncio +from typing import Union from taskiq_cancellation.message import CancellationMessage @@ -11,20 +12,25 @@ class InMemoryCancellationNotifier(QueueCancellationNotifier): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - - self.messages: asyncio.Queue[CancellationMessage] = asyncio.Queue() + + # In Python 3.9 queues must be created inside a running loop + # Source: https://stackoverflow.com/questions/53724665 + self.messages: Union[asyncio.Queue[CancellationMessage], None] = None async def cancel(self, task_id: str) -> None: + if self.messages is None: + self.messages = asyncio.Queue() + timestamp = time.time() - + await self.messages.put( - CancellationMessage( - task_id=task_id, - timestamp=timestamp - ) + CancellationMessage(task_id=task_id, timestamp=timestamp) ) async def _listen(self, started_listening: asyncio.Event) -> None: + if self.messages is None: + self.messages = asyncio.Queue() + loop = asyncio.get_running_loop() loop.call_soon_threadsafe(started_listening.set) diff --git a/src/taskiq_cancellation/notifiers/null.py b/src/taskiq_cancellation/notifiers/null.py index 9740012..bb19211 100644 --- a/src/taskiq_cancellation/notifiers/null.py +++ b/src/taskiq_cancellation/notifiers/null.py @@ -1,19 +1,20 @@ import asyncio -from anyio.abc import TaskStatus -from taskiq_cancellation.abc.notifier import CancellationNotifier +from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent class NullCancellationNotifier(CancellationNotifier): """ \"Do nothing\" cancellation notifier - + May be useful if there's no need or ability to use an actual notifier """ - + async def cancel(self, task_id: str) -> None: pass - async def listen_for_cancellation(self, task_id: str, started_listening_task_status: TaskStatus) -> None: - started_listening_task_status.started() + async def listen_for_cancellation( + self, task_id: str, started_listening_event: StartedListeningEvent + ) -> None: + await started_listening_event.set() await asyncio.sleep(float("+inf")) diff --git a/src/taskiq_cancellation/notifiers/queue.py b/src/taskiq_cancellation/notifiers/queue.py index 3190501..16ea174 100644 --- a/src/taskiq_cancellation/notifiers/queue.py +++ b/src/taskiq_cancellation/notifiers/queue.py @@ -1,10 +1,9 @@ import abc import weakref import asyncio +from typing import Union -from anyio.abc import TaskStatus - -from taskiq_cancellation.abc import CancellationNotifier +from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent from taskiq_cancellation.exceptions import TaskCancellationException from taskiq_cancellation.message import CancellationMessage @@ -16,10 +15,11 @@ class QueueCancellationNotifier(CancellationNotifier): Requires :func:`_listen` to be implemeted """ + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.listener_task: asyncio.Task | None = None + self.listener_task: Union[asyncio.Task, None] = None self.queues: weakref.WeakSet[asyncio.Queue[CancellationMessage]] = ( weakref.WeakSet() ) @@ -28,9 +28,10 @@ def __init__(self, **kwargs) -> None: async def shutdown(self) -> None: if self.listener_task is not None: self.listener_task.cancel() + await asyncio.wait([self.listener_task]) async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: cancellations: asyncio.Queue[CancellationMessage] = asyncio.Queue() @@ -38,7 +39,7 @@ async def listen_for_cancellation( await self._create_listener_task() await self._subscribe(cancellations) - started_listening_task_status.started() + await started_listening_event.set() while True: cancellation_message = await cancellations.get() diff --git a/src/taskiq_cancellation/notifiers/redis.py b/src/taskiq_cancellation/notifiers/redis.py index 2efa7c2..fe7746e 100644 --- a/src/taskiq_cancellation/notifiers/redis.py +++ b/src/taskiq_cancellation/notifiers/redis.py @@ -10,12 +10,21 @@ class PubSubCancellationNotifier(QueueCancellationNotifier): + """Cancellation notifier using Redis pub/sub""" + CHANNEL_NAME = "__taskiq_cancellation_notifications" - def __init__(self, url: str, **kwargs) -> None: - super().__init__(**kwargs) + def __init__(self, url: str, **connection_kwargs) -> None: + """ + Creates AioPika notifier - self.connection_pool = redis.BlockingConnectionPool.from_url(url, **kwargs) + :param url: url to redis + :type url: str + :param connection_kwargs: arguments for :ref:`redis.BlockingConnectionPool.from_url` + """ + super().__init__() + + self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs) async def cancel(self, task_id: str) -> None: timestamp = time.time() @@ -54,3 +63,8 @@ async def _listen(self, started_listening: asyncio.Event): ) for queue in self.queues: await queue.put(cancellation_message) + + async def shutdown(self) -> None: + await super().shutdown() + await self.connection_pool.aclose() + \ No newline at end of file diff --git a/src/taskiq_cancellation/state_holders/in_memory.py b/src/taskiq_cancellation/state_holders/in_memory.py index 23a7089..b63277e 100644 --- a/src/taskiq_cancellation/state_holders/in_memory.py +++ b/src/taskiq_cancellation/state_holders/in_memory.py @@ -3,10 +3,10 @@ class InMemoryCancellationStateHolder(CancellationStateHolder): """In memory cancellation state holder used for testing""" - + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - + self.state_holder: dict[str, bool] = {} async def cancel(self, task_id: str) -> None: diff --git a/src/taskiq_cancellation/state_holders/null.py b/src/taskiq_cancellation/state_holders/null.py index 4383071..b24462d 100644 --- a/src/taskiq_cancellation/state_holders/null.py +++ b/src/taskiq_cancellation/state_holders/null.py @@ -4,10 +4,10 @@ class NullCancellationStateHolder(CancellationStateHolder): """ \"Do nothing\" cancellation state holder - + May be useful if there's no need or ability to use an actual state holder """ - + async def cancel(self, task_id: str) -> None: pass diff --git a/src/taskiq_cancellation/state_holders/redis.py b/src/taskiq_cancellation/state_holders/redis.py index 262f121..be86f08 100644 --- a/src/taskiq_cancellation/state_holders/redis.py +++ b/src/taskiq_cancellation/state_holders/redis.py @@ -17,5 +17,10 @@ async def is_cancelled(self, task_id: str) -> bool: response = await conn.get(self._task_key(task_id)) return bool(response) + async def shutdown(self) -> None: + await super().shutdown() + await self.connection_pool.aclose() + def _task_key(self, task_id: str) -> str: return f"__cancellation_status_{task_id}" + diff --git a/src/taskiq_cancellation/utils.py b/src/taskiq_cancellation/utils.py index 4d659f0..a128266 100644 --- a/src/taskiq_cancellation/utils.py +++ b/src/taskiq_cancellation/utils.py @@ -5,7 +5,9 @@ from collections import OrderedDict -def combines(wrapped): +def combines( + wrapped: typing.Callable, add_var_parameters: bool = False +) -> typing.Callable: """ Combines wrapped and wrapper functions signatures and type hints @@ -28,6 +30,11 @@ def foo(a: int, b = "lol"): print(inspect.signature(foo)) # (a: int, c: int, b='lol', *args, **kwargs) ''' + + :param wrapped: function to be wrapped + :type wrapped: Callable + :param add_var_parameters: add *args and **kwargs from wrapper to new signature + :type add_var_parameters: bool """ wrapped_signature: inspect.Signature = inspect.signature(wrapped) wrapped_type_hints: typing.Dict[str, str] = typing.get_type_hints(wrapped) @@ -42,9 +49,19 @@ def decorator(wrapper): f"Parameter {param_name} will be overwritten by wrapper function" ) - parameters = OrderedDict( - wrapped_signature.parameters, **wrapper_signature.parameters - ) + wrapper_parameters = OrderedDict() + for name, parameter in wrapper_signature.parameters.items(): + if not add_var_parameters: + if any( + ( + parameter.kind is inspect.Parameter.VAR_POSITIONAL, + parameter.kind is inspect.Parameter.VAR_KEYWORD, + ) + ): + continue + wrapper_parameters[name] = parameter + + parameters = OrderedDict(wrapped_signature.parameters, **wrapper_parameters) parameters = sorted( parameters.values(), key=lambda p: p.kind + (0.5 if p.default != inspect.Parameter.empty else 0), @@ -70,4 +87,8 @@ def decorator(wrapper): return decorator +class StopTaskGroupException(Exception): + pass + + __all__ = ["combines"] diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/aiopika/__init__.py b/tests/integration/aiopika/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/aiopika/conftest.py b/tests/integration/aiopika/conftest.py new file mode 100644 index 0000000..53e27dc --- /dev/null +++ b/tests/integration/aiopika/conftest.py @@ -0,0 +1,14 @@ +import os +import pytest + +from taskiq import InMemoryBroker + + +@pytest.fixture +def rabbitmq_url(): + return os.environ.get("TEST_RABBITMQ_URL", "amqp://guest:guest@localhost:5672") + + +@pytest.fixture +def broker(): + return InMemoryBroker() diff --git a/tests/integration/aiopika/test_notifier.py b/tests/integration/aiopika/test_notifier.py new file mode 100644 index 0000000..11e6553 --- /dev/null +++ b/tests/integration/aiopika/test_notifier.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.notifiers.aiopika import AioPikaNotifier + +from ..common.cancellations import run_notifier_cancellation_test + + +@pytest.fixture +def notifier(rabbitmq_url): + return AioPikaNotifier(url=rabbitmq_url) + + +@pytest.mark.asyncio +async def test_cancellation(notifier: AioPikaNotifier): + await run_notifier_cancellation_test(notifier) diff --git a/tests/integration/common/cancellations.py b/tests/integration/common/cancellations.py new file mode 100644 index 0000000..e495604 --- /dev/null +++ b/tests/integration/common/cancellations.py @@ -0,0 +1,94 @@ +import sys +import uuid +import pytest +import asyncio + +from taskiq import InMemoryBroker + +from taskiq_cancellation.abc import CancellationBackend, CancellationNotifier, CancellationStateHolder +from taskiq_cancellation.backends.modular import ModularCancellationBackend +from taskiq_cancellation.notifiers.null import NullCancellationNotifier +from taskiq_cancellation.state_holders.null import NullCancellationStateHolder +from taskiq_cancellation.exceptions import TaskCancellationException + +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout + + + +async def run_backend_cancellation_test(backend: CancellationBackend): + broker = InMemoryBroker() + backend = backend.with_broker(broker) + + @broker.task + @backend.cancellable + async def test_task(): + pass + + await broker.startup() + + task = await test_task.kiq() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +async def run_notifier_cancellation_test(notifier: CancellationNotifier): + broker = InMemoryBroker() + backend = ModularCancellationBackend( + NullCancellationStateHolder(), + notifier + ).with_broker(broker) + + task_started = asyncio.Event() + + @broker.task + @backend.cancellable + async def test_task(): + task_started.set() + await asyncio.sleep(0.2) + + await broker.startup() + + task = await test_task.kiq() + async with timeout(1): + await task_started.wait() + + await backend.cancel(task.task_id) + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +async def run_state_holder_cancellation_test(state_holder: CancellationStateHolder): + broker = InMemoryBroker() + backend = ModularCancellationBackend( + state_holder, + NullCancellationNotifier() + ).with_broker(broker) + + @broker.task + @backend.cancellable + async def test_task(): + # This task is supposed to never start + assert False + + await broker.startup() + + task_id = str(uuid.uuid4()) + await backend.cancel(task_id) + task = await test_task.kicker().with_task_id(task_id).kiq() + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() diff --git a/tests/integration/redis/__init__.py b/tests/integration/redis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/redis/conftest.py b/tests/integration/redis/conftest.py new file mode 100644 index 0000000..1823f42 --- /dev/null +++ b/tests/integration/redis/conftest.py @@ -0,0 +1,14 @@ +import os +import pytest + +from taskiq import InMemoryBroker + + +@pytest.fixture +def redis_url(): + return os.environ.get("TEST_REDIS_URL", "redis://localhost:6379") + + +@pytest.fixture +def broker(): + return InMemoryBroker() diff --git a/tests/integration/redis/test_backend.py b/tests/integration/redis/test_backend.py new file mode 100644 index 0000000..a5bde89 --- /dev/null +++ b/tests/integration/redis/test_backend.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.backends.redis import RedisCancellationBackend + +from ..common.cancellations import run_backend_cancellation_test + + +@pytest.fixture +def redis_backend(redis_url, broker): + return RedisCancellationBackend(url=redis_url).with_broker(broker) + + +@pytest.mark.asyncio +async def test_cancellation(redis_backend: RedisCancellationBackend): + await run_backend_cancellation_test(redis_backend) diff --git a/tests/integration/redis/test_pubsub.py b/tests/integration/redis/test_pubsub.py new file mode 100644 index 0000000..b49720c --- /dev/null +++ b/tests/integration/redis/test_pubsub.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.notifiers.redis import PubSubCancellationNotifier + +from ..common.cancellations import run_notifier_cancellation_test + + +@pytest.fixture +def pubsub_notifier(redis_url): + return PubSubCancellationNotifier(url=redis_url) + + +@pytest.mark.asyncio +async def test_cancellation(pubsub_notifier: PubSubCancellationNotifier): + await run_notifier_cancellation_test(pubsub_notifier) diff --git a/tests/integration/redis/test_state_holder.py b/tests/integration/redis/test_state_holder.py new file mode 100644 index 0000000..f8be37b --- /dev/null +++ b/tests/integration/redis/test_state_holder.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder + +from ..common.cancellations import run_state_holder_cancellation_test + + +@pytest.fixture +def state_holder(redis_url): + return RedisCancellationStateHolder(url=redis_url) + + +@pytest.mark.asyncio +async def test_cancellation(state_holder: RedisCancellationStateHolder): + await run_state_holder_cancellation_test(state_holder) diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py deleted file mode 100644 index a790d1e..0000000 --- a/tests/test_cancellation.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest -import asyncio - -from taskiq import AsyncBroker, InMemoryBroker - -from taskiq_cancellation.abc import CancellationBackend -from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend -from taskiq_cancellation.exceptions import TaskCancellationException - - -@pytest.fixture -def broker(): - return InMemoryBroker() - - -@pytest.fixture -def backend(broker): - return InMemoryCancellationBackend().with_broker(broker) - - -@pytest.mark.asyncio -async def test_task_success(broker: AsyncBroker, backend: CancellationBackend): - """Test that cancellable task can run successfully""" - - @broker.task - @backend.cancellable - async def task(): - await asyncio.sleep(0.1) - - await broker.startup() - - t = await task.kiq() - - result = await t.wait_result() - assert result.is_err is False - - await broker.shutdown() - - -@pytest.mark.asyncio -async def test_task_cancellation(broker: AsyncBroker, backend: CancellationBackend): - """Test that cancellable task can be cancelled""" - - @broker.task - @backend.cancellable - async def task(): - await asyncio.sleep(0.3) - - await broker.startup() - - t = await task.kiq() - await backend.cancel(t.task_id) - - with pytest.raises(TaskCancellationException): - result = await t.wait_result() - result.raise_for_error() - - await broker.shutdown() diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py new file mode 100644 index 0000000..459492d --- /dev/null +++ b/tests/unit/test_backend.py @@ -0,0 +1,50 @@ +import pytest +import inspect + +from taskiq_cancellation.abc.backend import CancellationBackend +from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend + + +@pytest.fixture +def backend(): + return InMemoryCancellationBackend() + + +@pytest.mark.asyncio +async def test_decorator_without_parentesis(backend: CancellationBackend): + """Tests that cancellable decorator works without parentesis""" + + @backend.cancellable + async def test_task(): + pass + + task = test_task() + assert inspect.iscoroutine(task) + await task + + +@pytest.mark.asyncio +async def test_decorator_with_parentesis(backend: CancellationBackend): + """Tests that cancellable decorator works with parentesis""" + + @backend.cancellable() + async def test_task(): + pass + + task = test_task() + assert inspect.iscoroutine(task) + await task + + +@pytest.mark.asyncio +async def test_decorator_with_sync_function(backend: CancellationBackend): + """ + Tests that cancellable decorator doesn't work with synchronous functions + + To launch a synchronous function we need to know how to do it and only Receiver knows that + """ + + with pytest.raises(ValueError): + @backend.cancellable + def test_task(): + pass diff --git a/tests/unit/test_cancellation.py b/tests/unit/test_cancellation.py new file mode 100644 index 0000000..892e988 --- /dev/null +++ b/tests/unit/test_cancellation.py @@ -0,0 +1,237 @@ +import sys +import pytest +import asyncio + +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout + +import anyio +from taskiq import AsyncBroker, InMemoryBroker + +from taskiq_cancellation.abc.backend import CancellationBackend, CancellationType +from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend +from taskiq_cancellation.exceptions import TaskCancellationException + + +@pytest.fixture +def broker(): + return InMemoryBroker() + + +@pytest.fixture +def backend(broker: AsyncBroker): + return InMemoryCancellationBackend().with_broker(broker) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("cancellation_type"), (CancellationType.LEVEL, CancellationType.EDGE) +) +async def test_task_direct_call( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, +): + """Tests that cancellable function can be called directly""" + + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + return True + + result = await test_task() + assert result + + +class TestTaskSuccess: + types = [CancellationType.LEVEL] + if sys.version_info >= (3, 11): + types.append(CancellationType.EDGE) + + @pytest.mark.asyncio + @pytest.mark.parametrize(("cancellation_type"), types) + @staticmethod + async def test_task_success( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, + ): + """Tests that cancellable task can successfully finish""" + + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + await asyncio.sleep(0.1) + + await broker.startup() + + task = await test_task.kiq() + result = await task.wait_result() + assert result.is_err is False + + await broker.shutdown() + + +class TestTaskCancellation: + types = [CancellationType.LEVEL] + if sys.version_info > (3, 11): + types.append(CancellationType.EDGE) + + @pytest.mark.asyncio + @pytest.mark.parametrize(("cancellation_type"), types) + @staticmethod + async def test_task_cancellation( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, + ): + """Tests that cancellable task can successfully cancel""" + + started_event = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + with pytest.raises(anyio.get_cancelled_exc_class()): + started_event.set() + await asyncio.sleep(0.2) + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + async with timeout(1): + await started_event.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +class TestEdgeCancellation: + @pytest.mark.asyncio + @pytest.mark.skipif( + sys.version_info >= (3, 11), + reason="Edge cancellation is currently not supported for Python <3.11", + ) + async def test_non_supported( + self, broker: AsyncBroker, backend: CancellationBackend + ): + """Tests that edge cancellation raises NotImplementedError in Python <3.11""" + + @broker.task + @backend.cancellable(cancellation_type=CancellationType.EDGE) + async def test_task(): + await asyncio.sleep(0.1) + + await broker.startup() + + task = await test_task.kiq() + + with pytest.raises(NotImplementedError): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + @pytest.mark.asyncio + @pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Edge cancellation is currently not supported for Python <3.11", + ) + async def test_cancellation_interception( + self, broker: AsyncBroker, backend: CancellationBackend + ): + """ + Tests that tasks can capture task cancellation + + asyncio raises asyncio.CancelledError only once. Task can intercept that to do cleanup, + but also can just ignore the cancellation request. + Docs: https://docs.python.org/3/library/asyncio-task.html#task-cancellation + """ + + cancelled_for_second_time = False + task_started = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=CancellationType.EDGE) + async def test_task(): + nonlocal cancelled_for_second_time + + try: + task_started.set() + await asyncio.sleep(0.5) + except asyncio.CancelledError: + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + cancelled_for_second_time = True + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + async with timeout(1): + await task_started.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + assert cancelled_for_second_time is False + + await broker.shutdown() + + +class TestLevelCancellation: + @pytest.mark.asyncio + async def test_repeated_cancellation( + self, broker: AsyncBroker, backend: CancellationBackend + ): + """ + Tests that task will have multiple cancellation exceptions + + anyio raises cancellation exception on every await + Docs: https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics + """ + + cancelled_for_second_time = False + started_event = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=CancellationType.LEVEL) + async def test_task(): + nonlocal cancelled_for_second_time + + try: + started_event.set() + await asyncio.sleep(1) + except anyio.get_cancelled_exc_class(): + # anyio cancels on any await after scope's cancellation + try: + await asyncio.sleep(0) + except anyio.get_cancelled_exc_class(): + cancelled_for_second_time = True + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + async with timeout(1): + await started_event.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + assert cancelled_for_second_time is True + + await broker.shutdown()