diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..cebff2b --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,33 @@ +name: Lint + +on: + pull_request: + branches: [develop, main] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: 3.14 + + - name: Setup uv + uses: astral-sh/setup-uv@v7 + + - name: Create virtual environment + run: uv venv .venv && source .venv/bin/activate + + - name: Install modules + run: uv sync + + - name: Check code style + run: uv run ruff check + + - name: Check static typing + run: uv run mypy . diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..9473ce2 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,28 @@ +name: Release on PyPI + +on: + release: + types: + - released + +jobs: + publish: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Python 3.14 + uses: actions/setup-python@v6 + with: + python-version: 3.14 + + - name: Setup uv + uses: astral-sh/setup-uv@v6 + + - name: Build package + run: uv build + + - name: Publish package + run: uv publish diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml new file mode 100644 index 0000000..ca80840 --- /dev/null +++ b/.github/workflows/run_tests.yaml @@ -0,0 +1,65 @@ +name: Testing + +on: + pull_request: + branches: [develop, main] + +jobs: + run-unit-tests: + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + os: [ubuntu-latest, windows-latest, macos-latest] + fail-fast: false + + runs-on: ${{ matrix.os }} + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Setup uv + uses: astral-sh/setup-uv@v7 + + # - name: Create virtual environment + # run: uv venv .venv && source .venv/bin/activate + + - name: Install modules + run: uv sync + + - name: Run unit tests for Python ${{ matrix.python-version }} + run: uv run pytest tests/unit + + run-integration-tests: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + fail-fast: false + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Setup uv + uses: astral-sh/setup-uv@v7 + + - name: Install modules + run: uv sync --extra redis --extra aiopika + + - name: Setup Docker containers + run: docker compose -f ./docker-compose-tests.yml up --wait + + - name: Run integration tests for Python ${{ matrix.python-version }} + run: uv run pytest tests/integration diff --git a/README.md b/README.md index aa1a64f..57abe86 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,172 @@ -[![PyPI - Version](https://img.shields.io/pypi/v/taskiq-cancellation.svg)](https://pypi.org/project/taskiq-cancellation) -[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/taskiq-cancellation.svg)](https://pypi.org/project/taskiq-cancellation) +
+ taskiq-cancellation logo +
-# Task cancellation for taskiq +[![PyPI - Version](https://img.shields.io/pypi/v/taskiq-cancellation.svg?style=for-the-badge)](https://pypi.org/project/taskiq-cancellation) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/taskiq-cancellation.svg?style=for-the-badge)](https://pypi.org/project/taskiq-cancellation) + +**[taskiq-cancellation](https://pypi.org/project/taskiq-cancellation)** aims to be a drop-in task cancellation solution for taskiq as the original package doesn't provide a cancellation API. + +## Contents: + +- [Installation](#installation) +- [Usage](#usage) + - [What is a cancellation backend?](#what-is-a-cancellation-backend) + - [Modular cancellation backend](#modular-cancellation-backend) + - [Available integrations](#available-integrations) + - [Level and edge cancellation](#level-and-edge-cancellation) + - [Retry middlewares with task cancellation](#retry-middlewares-with-task-cancellation) +- [Development](#development) +- [Contributing](#contributing) ## Installation -```console +This package can be install from PyPI with your package manager of choice. + +```bash pip install taskiq-cancellation +pipx install taskiq-cancellation +poetry add taskiq-cancellation +uv add taskiq-cancellation +``` + +taskiq-cancellation currently provides integrations with Redis and RabbitMQ that are installable with `redis` and `aiopika` extras respectfully. + +```bash +pip install taskiq-cancellation[redis,aiopika] ``` ## Usage + +To do task cancellation, you need to: + +1. Create a cancellation backend +2. Wrap a function with `cancellable` decorator +3. Cancel the task with `cancel(task_id)` + +```python +broker = PubSubBroker(url).with_result_backend(RedisAsyncResultBackend(url)) +cancellation_backend = RedisCancellationBackend(url).with_broker(broker) + +@broker.task +@cancellation_backend.cancellable +async def sleep(seconds: int): + await asyncio.sleep(seconds) + print("Slept!") # Won't be printed on worker side because of the cancellation + +async def main(): + await broker.startup() + + task = await sleep.kiq(5) + await cancellation_backend.cancel(task.task_id) + + await broker.shutdown() + +asyncio.run(main()) +``` + +### What is a cancellation backend? + +**Cancellation backend** can be seen as combination of a broker and result backend for cancellation messages that works underneath taskiq's broker. Cancellation backend won't run tasks marked as cancelled and will listen for cancellation messages for already running tasks. + +
+ Cancellation backend example scheme +
+ +### Modular cancellation backend + +To easily create cancellation backends taskiq-cancellation provides `ModularCancellationBackend`. Modular cancellation backend consists of two parts: state holder and notifier. + +- State holder is used to check for task cancellation status before running the task. +- Notifier is used to listen for cancellation messages while running the task + +This allows to use any techonology for task cancellation. For example, if one uses SQL database and RabbitMQ message broker, they can make a custom state holder with SQL library of their choice and use provided RabbitMQ notifier. + +```python +from taskiq_cancellation import ModularCancellationBackend +from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder +from taskiq_cancellation.notifiers.aiopika import AioPikaCancellationNotifier + +backend = ModularCancellationBackend( + RedisCancellationStateHolder("redis://localhost:6379"), + AioPikaCancellationNotifier("amqp://guest:guest@localhost:5672") +) +``` + +### Available integrations + +taskiq-cancellation provides: + +- state holder for Redis (`RedisCancellationStateHolder`) +- notifiers for Redis pub/sub (`PubSubCancellationNotifier`) and RabbitMQ (`AioPikaCancellationNotifier`) + +Also there are `NullCancellationStateHolder` and `NullCancellationNotifier` that do absolutely nothing, if there's no need to not check for task cancellation before starting the task or no need to listen for cancellation of already running tasks. + +### Level and edge cancellation + +By default, taskiq-cancellation uses [`anyio`](https://anyio.readthedocs.io/en/stable/) and its [level cancellation](anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics). Level cancellation raises a cancellation exception on **every** asynchronous wait in a function. + +As external libraries might not support level cancellation, task-cancellation also provides [edge cancellation]() via `asyncio`. Edge cancellation raises an exception only _once_. To enable it, add `cancellation_type=CancellationType.EDGE` parameter to `cancellable` decorator. + +> [!WARNING] +> Currently edge cancellation is supported only for Python 3.11+ because it uses [`asyncio.TaskGroup`](https://docs.python.org/3/library/asyncio-task.html#asyncio.TaskGroup) + +Example: + +```python +from sqlalchemy.ext.asyncio import AsyncSession +from taskiq_cancellation import CancellationType + +@broker.task +@cancellation_backend.cancellable(cancellation_type=CancellationType.EDGE) +async def sleep(seconds: int): + session = AsyncSession(engine) + + try: + async with session.begin(): + await asyncio.sleep(seconds) + session.add(SleptFor(seconds)) + except asyncio.CancelledError: + # Won't raise cancelled exception + await session.close() + raise +``` + +### Retry middlewares with task cancellation + +If you use `SimpleRetryMiddleware` or `SmartRetryMiddleware`, make sure to add `TaskCancellationException` to `types_of_exceptions` parameter to not trigger additional retries. + +```python +from taskiq_cancellation.exceptions import TaskCancellationException + +broker = PubSubBroker(url) + .with_result_backend(RedisAsyncResultBackend(url)) + .with_middlewares( + SimpleRetryMiddleware( + types_of_exceptions=[TaskCancellationException, ] + ) + ) +``` + +## Development + +For linting, ruff is used + +```bash +ruff check +ruff format +``` + +For testing, pytest is used + +```bash +pytest tests/unit # Unit tests + +# Integration tests +docker compose -f docker-compose-tests.yml up --wait +pytest tests/integration +``` + +## Contributing + +If you have any issues with this package or have an idea for improvement, please don't hesitate to open an issue! This is my first open-source project so I would like to ask to be a little patient with me though 🙏 diff --git a/docker-compose-tests.yml b/docker-compose-tests.yml new file mode 100644 index 0000000..ffa0fbd --- /dev/null +++ b/docker-compose-tests.yml @@ -0,0 +1,20 @@ +services: + redis: + image: redis:latest + ports: + - "6379:6379" + + rabbitmq: + image: rabbitmq:latest + environment: + - RABBITMQ_DEFAULT_USER=guest + - RABBITMQ_DEFAULT_PASSWORD=guest + hostname: localhost + ports: + - "5672:5672" + healthcheck: + test: "rabbitmq-diagnostics check_running -q" + interval: 5s + timeout: 5s + retries: 10 + start_period: 5s \ No newline at end of file diff --git a/examples/counter/README.md b/examples/counter/README.md new file mode 100644 index 0000000..f3bc5be --- /dev/null +++ b/examples/counter/README.md @@ -0,0 +1,63 @@ +# Example: Counter + +This example demonstrates use of Redis cancellation backend with Redis broker. + +`sleep(seconds)` task sleeps for `seconds` seconds. First task will finish fully after 5 seconds and second task will cancel after 2.5 seconds. + +## How to run + +1. Install dependencies + +```bash +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +2. Launch redis server locally + +3. Launch worker + +```bash +taskiq worker main:broker --workers=1 +``` + +4. Launch client + +```bash +python main.py +``` + +## Expected result + +Client side: + +```console +Sending task and waiting 5 seconds... +Sending task and waiting 2.5 seconds... +Canceling task... +``` + +Worker side: + +```console +[2025-11-11 22:45:12,072][taskiq.receiver.receiver][INFO ][worker-0] Executing task main:count with ID: e2acdc76da94435d963b561b80098c47 +1 Mississippi +2 Mississippi +3 Mississippi +4 Mississippi +5 Mississippi +[2025-11-11 22:45:17,073][taskiq.receiver.receiver][INFO ][worker-0] Executing task main:count with ID: e094778df6de450d811c4701cfff608f +1 Mississippi +2 Mississippi +3 Mississippi +[2025-11-11 22:45:19,589][taskiq.receiver.receiver][ERROR ][worker-0] Exception found while executing function: +Traceback (most recent call last): + File "P:\Scratches\taskiq-cancellation\.venv\lib\site-packages\taskiq\receiver\receiver.py", line 254, in run_task + returned = await target_future + File "P:\Scratches\taskiq-cancellation\src\taskiq_cancellation\abc\backend.py", line 190, in wrapper + return await level_handler(*args, **kwargs) + File "P:\Scratches\taskiq-cancellation\src\taskiq_cancellation\cancellation_handlers\level.py", line 104, in __call__ + raise TaskCancellationException() +taskiq_cancellation.exceptions.TaskCancellationException +``` diff --git a/examples/counter/main.py b/examples/counter/main.py index 5facc63..61f9789 100644 --- a/examples/counter/main.py +++ b/examples/counter/main.py @@ -1,12 +1,12 @@ import asyncio from taskiq_redis import PubSubBroker, RedisAsyncResultBackend -from taskiq_cancellation.integrations.redis import RedisCancellationBackend +from taskiq_cancellation.backends.redis import RedisCancellationBackend url = "redis://localhost" broker = PubSubBroker(url).with_result_backend(RedisAsyncResultBackend(url)) -cancellation_backend = RedisCancellationBackend(url) +cancellation_backend = RedisCancellationBackend(url).with_broker(broker) @broker.task @@ -20,11 +20,14 @@ async def count(up_to: int): async def main(): await broker.startup() + print("Sending task and waiting 5 seconds...") task = await count.kiq(5) await asyncio.sleep(5) + print("Sending task and waiting 2.5 seconds...") task = await count.kiq(5) await asyncio.sleep(2.5) + print("Canceling task...") await cancellation_backend.cancel(task.task_id) await broker.shutdown() diff --git a/imgs/backend_scheme.png b/imgs/backend_scheme.png new file mode 100644 index 0000000..30fb5cb Binary files /dev/null and b/imgs/backend_scheme.png differ diff --git a/imgs/header.png b/imgs/header.png new file mode 100644 index 0000000..6956864 Binary files /dev/null and b/imgs/header.png differ diff --git a/pyproject.toml b/pyproject.toml index 5c1c6cf..f55868d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,31 +1,32 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - [project] name = "taskiq-cancellation" dynamic = ["version"] -description = 'Task cancellation mechanism for taskiq' +description = 'Task cancellation for taskiq' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "MIT" keywords = ["taskiq", "cancellation"] authors = [{ name = "Alexander Starikov", email = "acherryjam@gmail.com" }] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["taskiq"] +dependencies = [ + "async-timeout>=5.0.1 ; python_full_version < '3.11'", + "taskiq", + "typing-extensions>=4.15.0 ; python_full_version < '3.11'", +] [project.optional-dependencies] -redis = ["redis~=3.0"] +redis = ["redis~=6.0"] aiopika = ["aio_pika"] [project.urls] @@ -33,6 +34,10 @@ Documentation = "https://github.com/ACherryJam/taskiq-cancellation#readme" Issues = "https://github.com/ACherryJam/taskiq-cancellation/issues" Source = "https://github.com/ACherryJam/taskiq-cancellation" +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + [tool.hatch.version] path = "src/taskiq_cancellation/__about__.py" @@ -41,6 +46,21 @@ extra-dependencies = ["mypy>=1.0.0"] [tool.hatch.envs.types.scripts] check = "mypy --install-types --non-interactive {args:src/taskiq_cancellation tests}" +[tool.hatch.build.targets.sdist] +only-include = ["src/taskiq_cancellation"] + +[tool.hatch.build.targets.wheel] +packages = ["src/taskiq_cancellation"] + +[tool.mypy] +ignore_missing_imports = true +exclude = [ + # Added so that "mypy ." would work + "examples", + # Contains Python 3.11+ code, has to be excluded. Runtime checks don't work. + "src/taskiq_cancellation/cancellation_handlers/edge_3_11.py", +] + [tool.coverage.run] source_pkgs = ["taskiq_cancellation", "tests"] branch = true @@ -56,3 +76,22 @@ tests = ["tests", "*/taskiq-cancellation/tests"] [tool.coverage.report] exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[dependency-groups] +dev = [ + "mypy>=1.14.1", + "pytest>=8.3.5", + "pytest-asyncio>=0.24.0", + "ruff>=0.14.4", +] + +[tool.ruff] +# Edge cancellation is currently Python 3.11+ which causes linter to freak out +# For such versioning cases tests must be written +target-version = "py314" + +[[tool.uv.index]] +name = "testpypi" +url = "https://test.pypi.org/simple/" +publish-url = "https://test.pypi.org/legacy/" +explicit = true diff --git a/src/taskiq_cancellation/__about__.py b/src/taskiq_cancellation/__about__.py index d3ec452..f102a9c 100644 --- a/src/taskiq_cancellation/__about__.py +++ b/src/taskiq_cancellation/__about__.py @@ -1 +1 @@ -__version__ = "0.2.0" +__version__ = "0.0.1" diff --git a/src/taskiq_cancellation/__init__.py b/src/taskiq_cancellation/__init__.py index 4ad9e80..914f0b2 100644 --- a/src/taskiq_cancellation/__init__.py +++ b/src/taskiq_cancellation/__init__.py @@ -1,8 +1,12 @@ -from .abc import * -from .modular import ModularCancellationBackend +from .abc import CancellationBackend +from .backends.modular import ModularCancellationBackend +from .backends.in_memory import InMemoryCancellationBackend +from .cancellation_handlers.cancellation_type import CancellationType +from .exceptions import TaskCancellationException __all__ = [ - "ModularCancellationBackend" + "CancellationBackend", "ModularCancellationBackend", + "InMemoryCancellationBackend", "CancellationType", "TaskCancellationException" ] diff --git a/src/taskiq_cancellation/abc/__init__.py b/src/taskiq_cancellation/abc/__init__.py index cc8b52c..45239e2 100644 --- a/src/taskiq_cancellation/abc/__init__.py +++ b/src/taskiq_cancellation/abc/__init__.py @@ -1,10 +1,12 @@ from .backend import CancellationBackend from .notifier import CancellationNotifier from .state_holder import CancellationStateHolder +from .started_listening_event import StartedListeningEvent __all__ = [ "CancellationBackend", "CancellationNotifier", - "CancellationStateHolder" + "CancellationStateHolder", + "StartedListeningEvent", ] diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index 4f5ff15..6f20b5e 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -1,98 +1,213 @@ import abc -import asyncio -import traceback -from typing import Callable, Annotated, ParamSpec, TypeVar, Awaitable +import sys +import inspect +from typing import Callable, Annotated, overload, Optional, cast, Union -import anyio -from anyio.abc import TaskStatus -from taskiq import Context, TaskiqDepends +if sys.version_info >= (3, 11): + from typing import Self, ParamSpec, TypeVar +else: + from typing_extensions import Self, ParamSpec, TypeVar -from ..utils import combines -from ..exceptions import TaskCancellationException +from taskiq import Context, TaskiqDepends, AsyncBroker, TaskiqEvents, TaskiqState +from taskiq_cancellation.utils import combines +from taskiq_cancellation.cancellation_handlers import ( + CancellationType, + LevelCancellationHandler, + EdgeCancellationHandler, +) -P = ParamSpec("P") -R = TypeVar("R") +from .started_listening_event import StartedListeningEvent + + +Params = ParamSpec("Params") +Result = TypeVar("Result") class CancellationBackend(abc.ABC): - @abc.abstractmethod + """ + Base class for cancellation backend + """ + + def __init__(self) -> None: + super().__init__() + + self.broker: Union[AsyncBroker, None] = None + + @abc.abstractmethod async def is_cancelled(self, task_id: str) -> bool: + """ + Checks if a task with task id of *task_id* is set to be cancelled + + :param task_id: task id to check + :type task_id: str + :returns: task cancellation state + :rtype: bool + """ pass @abc.abstractmethod async def cancel(self, task_id: str) -> None: + """ + Cancels a task with task id of *task_id* + + :param task_id: id of the task to cancel + :type task_id: str + """ pass @abc.abstractmethod async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: + """ + Listens for cancellation messages and raises :ref:`TaskCancellationException` when + receives :ref:`CancellationMessage` with same id as *task_id*. + + This function is used in :func:`cancellable` decorator. + Call `started_listening_task_status.started()` when the listener is ready + to receive messages. + + :param task_id: id of task that will be listened for + :type task_id: str + :param started_listening_task_status: + :type started_listening_task_status: anyio.abc.TaskStatus + """ pass - - def cancellable(self, task: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - # Executor type depends on receiver configuration which we - # can't access in any way - if not asyncio.iscoroutinefunction(task): - raise ValueError("Can't cancel synchronous function") - - @combines(task) - async def wrapper( - *args, - __taskiq_context: Annotated[Context, TaskiqDepends()], - **kwargs - ): - task_id = __taskiq_context.message.task_id - result = None - - listener_exception: Exception | None = None - task_exception: Exception | None = None - cancelled_by_request: bool = False - - async with anyio.create_task_group() as group: - async def listen_for_cancellation( - task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED - ): - nonlocal listener_exception, cancelled_by_request - - try: - await self.listen_for_cancellation(task_id, task_status) - except TaskCancellationException: - cancelled_by_request = True - except anyio.get_cancelled_exc_class(): - pass - except Exception as e: - listener_exception = e - finally: - group.cancel_scope.cancel() - - async def call_task(): - nonlocal result, task_exception - - try: - result = await task(*args, **kwargs) - except anyio.get_cancelled_exc_class(): - pass - except Exception as e: - task_exception = e - finally: - group.cancel_scope.cancel() - - # Listen before checking for cancellation in database so - # the message won't get lost in non-persistent queues - await group.start(listen_for_cancellation) - if await self.is_cancelled(task_id): - cancelled_by_request = True - group.cancel_scope.cancel() - else: - group.start_soon(call_task) - - if task_exception is not None: - raise task_exception - elif cancelled_by_request: - raise TaskCancellationException() - elif listener_exception is not None: - raise listener_exception - else: - return result - return wrapper + + async def startup(self) -> None: + """ + Starts up cancellation backend + + Triggered only if backend has a broker set. To set a broker use :func:`with_broker`. + """ + pass + + async def shutdown(self) -> None: + """Shuts down cancellation backend + + Triggered only if backend has a broker set. To set a broker use :ref:`with_broker`. + """ + pass + + def with_broker(self, broker: AsyncBroker) -> Self: + """ + Set a broker and return updated cancellation backend + + Sets up startup and shutdown event handlers for backend's startup + and shutdown methods respectfully + + :param broker: new broker + :type broker: AsyncBroker + :returns: self + """ + if self.broker is not None: + self.broker.event_handlers[TaskiqEvents.CLIENT_STARTUP].remove( + self._broker_startup_handler + ) + self.broker.event_handlers[TaskiqEvents.WORKER_STARTUP].remove( + self._broker_startup_handler + ) + self.broker.event_handlers[TaskiqEvents.CLIENT_SHUTDOWN].remove( + self._broker_shutdown_handler + ) + self.broker.event_handlers[TaskiqEvents.WORKER_SHUTDOWN].remove( + self._broker_shutdown_handler + ) + + self.broker = broker + self.broker.add_event_handler( + TaskiqEvents.CLIENT_STARTUP, self._broker_startup_handler + ) + self.broker.add_event_handler( + TaskiqEvents.WORKER_STARTUP, self._broker_startup_handler + ) + self.broker.add_event_handler( + TaskiqEvents.CLIENT_SHUTDOWN, self._broker_shutdown_handler + ) + self.broker.add_event_handler( + TaskiqEvents.WORKER_SHUTDOWN, self._broker_shutdown_handler + ) + + return self + + @overload + def cancellable( + self, cancellation_type: Callable[Params, Result] + ) -> Callable[Params, Result]: + pass + + @overload + def cancellable( + self, cancellation_type: Optional[CancellationType] = None + ) -> Callable[[Callable[Params, Result]], Callable[Params, Result]]: + pass + + def cancellable(self, cancellation_type=None): + """ + Decorator that makes funcion cancellable + + This decorator makes a new function that creates two tasks in :ref:`anyio.TaskGroup`: + 1. Cancellation message listener (uses :ref:`listen_for_cancellation`) + 2. Wrapped function + + - Returns function's result/exception if it finishes successfully/unsuccessfully + - Raises :ref:`TaskCancellationException` if listener task receives cancellation message + - If listener task raises an exception, task is cancelled and exception is propogated upwards + + :param cancellation_type: type of cancellation used + :type cancellation_type: CancellationType + :returns: Cancellable task function + """ + + defaults = {"cancellation_type": CancellationType.LEVEL} + + def make_decorator(cancellation_type: CancellationType): + def decorator( + task: Callable[Params, Result], / + ) -> Callable[Params, Result]: + # Executor type depends on receiver configuration which we can't accessed in any way + if not inspect.iscoroutinefunction(task): + raise ValueError("Can't cancel synchronous function") + + @combines(task) + async def wrapper( + *args, + __taskiq_context: Annotated[Context, TaskiqDepends(Context)] = None, # type: ignore + **kwargs, + ) -> Result: + if __taskiq_context is None: + # Ran the function directly, without kiq + return await task(*args, **kwargs) + + task_id = __taskiq_context.message.task_id + + if cancellation_type is CancellationType.EDGE: + edge_handler = EdgeCancellationHandler(self, task, task_id) + return await edge_handler(*args, **kwargs) + elif cancellation_type is CancellationType.LEVEL: + level_handler = LevelCancellationHandler(self, task, task_id) + return await level_handler(*args, **kwargs) + else: + raise ValueError( + f"Unknown cancellation type: {cancellation_type!r}" + ) + + # Wrapper adds a key-word only param with default value + casted_wrapper = cast(Callable[Params, Result], wrapper) + return casted_wrapper + + return decorator + + if callable(cancellation_type): + return make_decorator(**defaults)(cancellation_type) + else: + return make_decorator( + cancellation_type=cancellation_type or defaults["cancellation_type"] + ) + + async def _broker_startup_handler(self, _: TaskiqState) -> None: + await self.startup() + + async def _broker_shutdown_handler(self, _: TaskiqState) -> None: + await self.shutdown() diff --git a/src/taskiq_cancellation/abc/notifier.py b/src/taskiq_cancellation/abc/notifier.py index 0faaa91..60b1d14 100644 --- a/src/taskiq_cancellation/abc/notifier.py +++ b/src/taskiq_cancellation/abc/notifier.py @@ -1,20 +1,67 @@ +import sys import abc -from anyio.abc import TaskStatus +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + from taskiq.abc.serializer import TaskiqSerializer from taskiq.serializers import JSONSerializer +from .started_listening_event import StartedListeningEvent + class CancellationNotifier(abc.ABC): - def __init__(self, serializer: TaskiqSerializer = JSONSerializer()): - self.serializer = serializer - + """Receives cancellation messages and notifies listeners of these messages""" + + def __init__(self): + self.serializer = JSONSerializer() + + async def startup(self) -> None: + """Starts up cancellation notifier""" + pass + + async def shutdown(self) -> None: + """Shuts down cancellation notifier""" + pass + @abc.abstractmethod async def cancel(self, task_id: str) -> None: + """ + Sends a cancellation message of a task with task id of *task_id* + + :param task_id: id of the task to cancel + :type task_id: str + """ pass @abc.abstractmethod async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: + """ + Listens for cancellation messages and raises :ref:`TaskCancellationException` when + receives :ref:`CancellationMessage` with same id as *task_id*. + + This function is used in :func:`cancellable` decorator of :ref:`ModularCancellationBackend`. + Call `started_listening_event.set()` when the listener is ready to receive messages. + + :param task_id: id of task that will be listened for + :type task_id: str + :param started_listening_event: "listener started listening" confirmation event + :type started_listening_event: StartedListeningEvent + """ pass + + def with_serializer(self, serializer: TaskiqSerializer) -> Self: + """ + Sets a serializer to be used by the notifier + + :param serializer: serializer for cancellation messages + :type serializer: TaskiqSerializer + :return: self + """ + self.serializer = serializer + + return self diff --git a/src/taskiq_cancellation/abc/started_listening_event.py b/src/taskiq_cancellation/abc/started_listening_event.py new file mode 100644 index 0000000..4e43acf --- /dev/null +++ b/src/taskiq_cancellation/abc/started_listening_event.py @@ -0,0 +1,22 @@ +import abc + + +class StartedListeningEvent(abc.ABC): + """ + A confirmation event for listeners to mark that they started listening to messages. API is + similar to :ref:`asyncio.Event`. + + This is needed for different cancellation types: + - Level cancellation uses :ref:`anyio.abc.TaskStatus` + - Edge cancellation uses :ref:`asyncio.Event` + """ + + @abc.abstractmethod + async def set(self): + """Sets the event""" + pass + + @abc.abstractmethod + async def wait(self): + """Waits for the event to be set""" + pass diff --git a/src/taskiq_cancellation/abc/state_holder.py b/src/taskiq_cancellation/abc/state_holder.py index 45df346..bb2a707 100644 --- a/src/taskiq_cancellation/abc/state_holder.py +++ b/src/taskiq_cancellation/abc/state_holder.py @@ -2,10 +2,34 @@ class CancellationStateHolder(abc.ABC): + """Holds cancellation state of Taskiq tasks""" + @abc.abstractmethod async def cancel(self, task_id: str) -> None: + """ + Sets a state of task with task id of *task_id* to be cancelled + + :param task_id: id of the task to cancel + :type task_id: str + """ pass @abc.abstractmethod async def is_cancelled(self, task_id: str) -> bool: + """ + Checks if a task with task id of *task_id* is set to be cancelled + + :param task_id: task id to check + :type task_id: str + :returns: task cancellation state + :rtype: bool + """ + pass + + async def startup(self) -> None: + """Starts up cancellation state holder""" + pass + + async def shutdown(self) -> None: + """Shuts down cancellation state holder""" pass diff --git a/src/taskiq_cancellation/backends/__init__.py b/src/taskiq_cancellation/backends/__init__.py new file mode 100644 index 0000000..5c65c46 --- /dev/null +++ b/src/taskiq_cancellation/backends/__init__.py @@ -0,0 +1,8 @@ +from .modular import ModularCancellationBackend +from .in_memory import InMemoryCancellationBackend + + +__all__ = [ + "ModularCancellationBackend", + "InMemoryCancellationBackend" +] diff --git a/src/taskiq_cancellation/backends/in_memory.py b/src/taskiq_cancellation/backends/in_memory.py new file mode 100644 index 0000000..795ac6e --- /dev/null +++ b/src/taskiq_cancellation/backends/in_memory.py @@ -0,0 +1,18 @@ +from taskiq_cancellation.notifiers.in_memory import InMemoryCancellationNotifier +from taskiq_cancellation.state_holders.in_memory import InMemoryCancellationStateHolder + +from .modular import ModularCancellationBackend + + +class InMemoryCancellationBackend(ModularCancellationBackend): + """ + Cancellation backend that stores state and notifications in memory + + Useful for testing purposes + """ + + def __init__(self, **kwargs): + super().__init__( + state_holder=InMemoryCancellationStateHolder(**kwargs), + notifier=InMemoryCancellationNotifier(**kwargs), + ) diff --git a/src/taskiq_cancellation/backends/modular.py b/src/taskiq_cancellation/backends/modular.py new file mode 100644 index 0000000..c65c00c --- /dev/null +++ b/src/taskiq_cancellation/backends/modular.py @@ -0,0 +1,70 @@ +import sys + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from taskiq.abc.serializer import TaskiqSerializer + +from taskiq_cancellation.abc import ( + CancellationBackend, + CancellationNotifier, + CancellationStateHolder, + StartedListeningEvent, +) + +import anyio + + +class ModularCancellationBackend(CancellationBackend): + """ + Modular cancellation backend made up of :class:`CancellationStateHolder` + and :class:`CancellationNotifier` + + - `CancellationStateHolder` stores cancellation state and blocks the task from being run. + - `CancellationNotifier` receives cancellation messages and cancels already running tasks. + """ + + def __init__( + self, state_holder: CancellationStateHolder, notifier: CancellationNotifier + ): + super().__init__() + + self.notifier: CancellationNotifier = notifier + self.state_holder: CancellationStateHolder = state_holder + + async def is_cancelled(self, task_id: str) -> bool: + return await self.state_holder.is_cancelled(task_id) + + async def cancel(self, task_id: str): + async with anyio.create_task_group() as group: + group.start_soon(self.state_holder.cancel, task_id) + group.start_soon(self.notifier.cancel, task_id) + + async def listen_for_cancellation( + self, task_id: str, started_listening_event: StartedListeningEvent + ): + await self.notifier.listen_for_cancellation(task_id, started_listening_event) + + async def startup(self) -> None: + await super().startup() + await self.state_holder.startup() + await self.notifier.startup() + + async def shutdown(self) -> None: + await super().shutdown() + await self.state_holder.shutdown() + await self.notifier.shutdown() + + def with_serializer(self, serializer: TaskiqSerializer) -> Self: + """ + Sets a serializer to be used by the notifier + + :param serializer: serializer for cancellation messages + :type serializer: TaskiqSerializer + :return: self + """ + self.notifier = self.notifier.with_serializer(serializer) + + return self diff --git a/src/taskiq_cancellation/backends/redis.py b/src/taskiq_cancellation/backends/redis.py new file mode 100644 index 0000000..3870e32 --- /dev/null +++ b/src/taskiq_cancellation/backends/redis.py @@ -0,0 +1,12 @@ +from taskiq_cancellation.backends.modular import ModularCancellationBackend + +from taskiq_cancellation.notifiers.redis import PubSubCancellationNotifier +from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder + + +class RedisCancellationBackend(ModularCancellationBackend): + def __init__(self, url: str, **kwargs) -> None: + super().__init__( + RedisCancellationStateHolder(url, **kwargs), + PubSubCancellationNotifier(url, **kwargs), + ) diff --git a/src/taskiq_cancellation/cancellation_handlers/__init__.py b/src/taskiq_cancellation/cancellation_handlers/__init__.py new file mode 100644 index 0000000..a228c8d --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/__init__.py @@ -0,0 +1,12 @@ +import sys + +from .cancellation_type import CancellationType +from .level import LevelCancellationHandler + +if sys.version_info >= (3, 11): + from .edge_3_11 import EdgeCancellationHandler +else: + from .edge_non_supported import EdgeCancellationHandler + + +__all__ = ["CancellationType", "LevelCancellationHandler", "EdgeCancellationHandler"] diff --git a/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py new file mode 100644 index 0000000..9714e67 --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py @@ -0,0 +1,8 @@ +import enum + + +class CancellationType(str, enum.Enum): + """Type of cancellation used by the backend""" + + EDGE = "edge" + LEVEL = "level" diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py new file mode 100644 index 0000000..4bf66b5 --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py @@ -0,0 +1,136 @@ +# FIXME: bunch of issues because of Python 3.11+ exclusivity +# +# Edge cancellation handler is using asyncio.TaskGroup which was introduced in 3.11 and +# uses expect-star syntax that was also introduced in the same version +# - mypy has to ignore this file because it can't finish static parsing +# - ruff's python version has to be set at 3.11+ so it wouldn't complain +# +# I'm not sure how to mitigate these issues. Maybe this can be put in a separate module somehow +# and then integrated? Maybe this can be rewritten to not use TaskGroup (probably easier to do)? + +import sys +import logging +import asyncio +from collections.abc import Coroutine +from typing import Callable, TYPE_CHECKING, Generic, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar + +from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent +from taskiq_cancellation.exceptions import TaskCancellationException +from taskiq_cancellation.utils import StopTaskGroupException + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class EdgeCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses edge cancellation provided by asyncio. That means :ref:`asyncio.CancelledError` is + raised only once for the task. + Docs: https://docs.python.org/3/library/asyncio-task.html#task-cancellation + + Currently is supported in Python 3.11+ due to using :ref:`asyncio.TaskGroup`. + """ + + class ListeningEvent(StartedListeningEvent): + def __init__(self) -> None: + self.event = asyncio.Event() + + async def set(self): + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(self.event.set) + + async def wait(self): + await self.event.wait() + + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ): + self.backend = backend + self.task = task + self.task_id = task_id + + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + result: Result = None # type: ignore + + listener_exception: Exception | None = None + task_exception: Exception | None = None + cancelled_by_request: bool = False + + async def listen_for_cancellation(event: StartedListeningEvent): + nonlocal listener_exception, cancelled_by_request + + try: + await self.backend.listen_for_cancellation(self.task_id, event) + except TaskCancellationException: + cancelled_by_request = True + raise + except asyncio.CancelledError: + raise + except Exception as e: + listener_exception = e + raise + + async def call_task(): + nonlocal result, task_exception + + try: + result = await self.task(*args, **kwargs) + except asyncio.CancelledError: + raise + except Exception as e: + task_exception = e + raise + + try: + async with asyncio.TaskGroup() as tg: + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues + event = self.ListeningEvent() + tg.create_task(listen_for_cancellation(event)) + await event.wait() + + if await self.backend.is_cancelled(self.task_id): + cancelled_by_request = True + raise StopTaskGroupException() + + task_task = asyncio.create_task(call_task()) + await task_task + if not task_task.cancelled(): + raise StopTaskGroupException() + except* StopTaskGroupException: + pass + except* Exception as exc_group: + uncaught_exceptions = list( + filter( + lambda e: e == task_exception or e == listener_exception, + exc_group.exceptions, + ) + ) + + if uncaught_exceptions: + logging.log(logging.ERROR, "Uncaught exceptions in TaskGroup") + for e in uncaught_exceptions: + logging.exception(e) + + if task_exception is not None: + raise task_exception + elif cancelled_by_request: + raise TaskCancellationException() + elif listener_exception is not None: + raise listener_exception + else: + return result diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py new file mode 100644 index 0000000..03d06fd --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py @@ -0,0 +1,34 @@ +import sys +from typing import Callable, TYPE_CHECKING, Coroutine, Generic, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class EdgeCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses edge cancellation provided by asyncio. Currently is supported in Python 3.11+ due + to using :ref:`asyncio.TaskGroup`. + """ + + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ) -> None: + raise NotImplementedError("Edge cancellation is not supported for Python <3.11") + + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + raise NotImplementedError("Edge cancellation is not supported for Python <3.11") diff --git a/src/taskiq_cancellation/cancellation_handlers/level.py b/src/taskiq_cancellation/cancellation_handlers/level.py new file mode 100644 index 0000000..2cfe71a --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/level.py @@ -0,0 +1,110 @@ +import sys +from collections.abc import Coroutine +from typing import Callable, TYPE_CHECKING, Generic, Union, cast, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar + +import anyio +from anyio.abc import TaskStatus + +from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent +from taskiq_cancellation.exceptions import TaskCancellationException + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class LevelCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses level cancellation provided by anyio. That means cancellation exception is raised + on every await in the coroutine. + Docs: https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics + """ + + class ListeningEvent(StartedListeningEvent): + def __init__(self, task_status: TaskStatus) -> None: + self.task_status = task_status + + async def set(self): + self.task_status.started() + + async def wait(self): + # Can ignore, won't execute further before task status is set + pass + + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ): + self.backend = backend + self.task = task + self.task_id = task_id + + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + result: Union[Result, None] = None + + listener_exception: Union[Exception, None] = None + task_exception: Union[Exception, None] = None + cancelled_by_request: bool = False + + async with anyio.create_task_group() as group: + + async def listen_for_cancellation( + task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ): + nonlocal listener_exception, cancelled_by_request + + event = self.ListeningEvent(task_status) + try: + await self.backend.listen_for_cancellation(self.task_id, event) + except TaskCancellationException: + cancelled_by_request = True + except anyio.get_cancelled_exc_class(): + pass + except Exception as e: + listener_exception = e + finally: + group.cancel_scope.cancel() + + async def call_task(): + nonlocal result, task_exception + + try: + result = await self.task(*args, **kwargs) + except anyio.get_cancelled_exc_class(): + pass + except Exception as e: + task_exception = e + finally: + group.cancel_scope.cancel() + + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues + await group.start(listen_for_cancellation) + if await self.backend.is_cancelled(self.task_id): + cancelled_by_request = True + group.cancel_scope.cancel() + else: + group.start_soon(call_task) + + if task_exception is not None: + raise task_exception + elif cancelled_by_request: + raise TaskCancellationException() + elif listener_exception is not None: + raise listener_exception + else: + # If the task is finished, it is definitely not None + result = cast(Result, result) + return result diff --git a/src/taskiq_cancellation/integrations/aiopika/__init__.py b/src/taskiq_cancellation/integrations/aiopika/__init__.py deleted file mode 100644 index 9d0488c..0000000 --- a/src/taskiq_cancellation/integrations/aiopika/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .notifier import AioPikaNotifier - - -__all__ = [ - "AioPikaNotifier" -] diff --git a/src/taskiq_cancellation/integrations/aiopika/notifier.py b/src/taskiq_cancellation/integrations/aiopika/notifier.py deleted file mode 100644 index b79ad78..0000000 --- a/src/taskiq_cancellation/integrations/aiopika/notifier.py +++ /dev/null @@ -1,65 +0,0 @@ -import time - -import aio_pika -from anyio.abc import TaskStatus -from taskiq.abc.serializer import TaskiqSerializer -from taskiq.serializers import JSONSerializer -from taskiq.compat import model_dump, model_validate - -from taskiq_cancellation.abc import CancellationNotifier -from taskiq_cancellation.exceptions import TaskCancellationException -from taskiq_cancellation.message import CancellationMessage - - -class AioPikaNotifier(CancellationNotifier): - EXCHANGE_NAME = "__taskiq_cancellation" - - def __init__(self, url: str, serializer: TaskiqSerializer = JSONSerializer()): - super().__init__(serializer) - - self.url: str = url - - async def cancel(self, task_id: str) -> None: - timestamp = time.time() - connection = await aio_pika.connect_robust( - self.url - ) - channel = await connection.channel() - - exchange = await channel.declare_exchange( - self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True - ) - - await exchange.publish( - aio_pika.Message(body=self.serializer.dumpb( - model_dump(CancellationMessage(task_id=task_id, timestamp=timestamp)) - )), - routing_key="" - ) - - async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus - ) -> None: - connection = await aio_pika.connect_robust( - self.url - ) - channel = await connection.channel() - - exchange = await channel.declare_exchange( - self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True - ) - queue = await channel.declare_queue(exclusive=True) - await queue.bind(exchange) - - started_listening_task_status.started() - - async with queue.iterator() as queue_iter: - async for message in queue_iter: - cancellation_message = model_validate( - CancellationMessage, - self.serializer.loadb(message.body) - ) - - if cancellation_message.task_id == task_id: - raise TaskCancellationException() - await message.ack() diff --git a/src/taskiq_cancellation/integrations/redis/__init__.py b/src/taskiq_cancellation/integrations/redis/__init__.py deleted file mode 100644 index fe38b68..0000000 --- a/src/taskiq_cancellation/integrations/redis/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .notifier import PubSubCancellationNotifier -from .state_holder import RedisCancellationStateHolder -from .backend import RedisCancellationBackend - - -__all__ = [ - "PubSubCancellationNotifier", - "RedisCancellationStateHolder", - "RedisCancellationBackend" -] diff --git a/src/taskiq_cancellation/integrations/redis/backend.py b/src/taskiq_cancellation/integrations/redis/backend.py deleted file mode 100644 index ad038d1..0000000 --- a/src/taskiq_cancellation/integrations/redis/backend.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Type - -from taskiq_cancellation.modular import ModularCancellationBackend - -from .notifier import PubSubCancellationNotifier -from .state_holder import RedisCancellationStateHolder - - -class RedisCancellationBackend(ModularCancellationBackend): - def __init__( - self, - url: str, - state_holder: Type = RedisCancellationStateHolder, - notifier: Type = PubSubCancellationNotifier, - **connection_kwargs - ) -> None: - super().__init__( - state_holder(url, **connection_kwargs), - PubSubCancellationNotifier(url, **connection_kwargs) - ) diff --git a/src/taskiq_cancellation/integrations/redis/notifier.py b/src/taskiq_cancellation/integrations/redis/notifier.py deleted file mode 100644 index 5ed24fb..0000000 --- a/src/taskiq_cancellation/integrations/redis/notifier.py +++ /dev/null @@ -1,52 +0,0 @@ -import time - -from anyio.abc import TaskStatus -import redis.asyncio as redis -from taskiq.compat import model_dump, model_validate - -from taskiq_cancellation.abc import CancellationNotifier -from taskiq_cancellation.exceptions import TaskCancellationException -from taskiq_cancellation.message import CancellationMessage - - -class PubSubCancellationNotifier(CancellationNotifier): - CHANNEL_NAME = "__taskiq_cancellation_notifications" - - def __init__(self, url: str, **connection_kwargs) -> None: - super().__init__() - self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs) - - async def cancel(self, task_id: str) -> None: - timestamp = time.time() - - async with redis.Redis(connection_pool=self.connection_pool) as conn: - await conn.publish( - self.CHANNEL_NAME, - self.serializer.dumpb( - model_dump(CancellationMessage(task_id=task_id, timestamp=timestamp)) - ) - ) - - async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus - ) -> None: - async with redis.Redis(connection_pool=self.connection_pool) as conn: - pubsub = conn.pubsub() - await pubsub.subscribe(self.CHANNEL_NAME) - - started_listening_task_status.started() - - while True: - message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=None) - - if message is None: - continue - if message["type"] != "message": - continue - - cancellation_message = model_validate( - CancellationMessage, - self.serializer.loadb(message['data']) - ) - if cancellation_message.task_id == task_id: - raise TaskCancellationException() diff --git a/src/taskiq_cancellation/modular.py b/src/taskiq_cancellation/modular.py deleted file mode 100644 index bead650..0000000 --- a/src/taskiq_cancellation/modular.py +++ /dev/null @@ -1,27 +0,0 @@ -from .abc import * - -import anyio -from anyio.abc import TaskStatus - - -class ModularCancellationBackend(CancellationBackend): - def __init__( - self, - state_holder: CancellationStateHolder, - notifier: CancellationNotifier - ): - self.notifier: CancellationNotifier = notifier - self.state_holder: CancellationStateHolder = state_holder - - async def is_cancelled(self, task_id: str) -> bool: - return await self.state_holder.is_cancelled(task_id) - - async def cancel(self, task_id: str): - async with anyio.create_task_group() as group: - group.start_soon(self.state_holder.cancel, task_id) - group.start_soon(self.notifier.cancel, task_id) - - async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus[None] - ): - await self.notifier.listen_for_cancellation(task_id, started_listening_task_status) diff --git a/src/taskiq_cancellation/notifiers/__init__.py b/src/taskiq_cancellation/notifiers/__init__.py new file mode 100644 index 0000000..9a097d6 --- /dev/null +++ b/src/taskiq_cancellation/notifiers/__init__.py @@ -0,0 +1,8 @@ +from .null import NullCancellationNotifier +from .in_memory import InMemoryCancellationNotifier + + +__all__ = [ + "NullCancellationNotifier", + "InMemoryCancellationNotifier" +] diff --git a/src/taskiq_cancellation/notifiers/aiopika.py b/src/taskiq_cancellation/notifiers/aiopika.py new file mode 100644 index 0000000..2432e56 --- /dev/null +++ b/src/taskiq_cancellation/notifiers/aiopika.py @@ -0,0 +1,76 @@ +import time +import asyncio + +import aio_pika +from taskiq.compat import model_dump, model_validate + +from taskiq_cancellation.message import CancellationMessage + +from .queue import QueueCancellationNotifier + + +class AioPikaNotifier(QueueCancellationNotifier): + """Notifier for RabbitMQ using aio-pika""" + + EXCHANGE_NAME = "__taskiq_cancellation" + + def __init__(self, url: str, **connection_kwargs): + """ + Creates AioPika notifier + + :param url: url to rabbitmq + :type url: str + :param connection_kwargs: arguments for :ref:`aio_pika.connect_robust` + """ + super().__init__() + + self.url: str = url + self.connection_kwargs = connection_kwargs + + async def cancel(self, task_id: str) -> None: + timestamp = time.time() + + connection = await aio_pika.connect_robust(self.url, **self.connection_kwargs) + + async with connection: + channel = await connection.channel() + + exchange = await channel.declare_exchange( + self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True + ) + + await exchange.publish( + aio_pika.Message( + body=self.serializer.dumpb( + model_dump( + CancellationMessage(task_id=task_id, timestamp=timestamp) + ) + ) + ), + routing_key="", + ) + + async def _listen(self, started_listening: asyncio.Event): + connection = await aio_pika.connect_robust(self.url, **self.connection_kwargs) + + async with connection: + channel = await connection.channel() + + exchange = await channel.declare_exchange( + self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True + ) + queue = await channel.declare_queue(exclusive=True, auto_delete=True) + await queue.bind(exchange) + + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(started_listening.set) + + async with queue.iterator() as queue_iter: + async for message in queue_iter: + cancellation_message = model_validate( + CancellationMessage, self.serializer.loadb(message.body) + ) + + for subscriber_queue in self.queues: + await subscriber_queue.put(cancellation_message) + await message.ack() diff --git a/src/taskiq_cancellation/notifiers/in_memory.py b/src/taskiq_cancellation/notifiers/in_memory.py new file mode 100644 index 0000000..1dfe59f --- /dev/null +++ b/src/taskiq_cancellation/notifiers/in_memory.py @@ -0,0 +1,41 @@ +import time +import asyncio +from typing import Union + +from taskiq_cancellation.message import CancellationMessage + +from .queue import QueueCancellationNotifier + + +class InMemoryCancellationNotifier(QueueCancellationNotifier): + """In memory cancellation notifier used for testing""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + # In Python 3.9 queues must be created inside a running loop + # Source: https://stackoverflow.com/questions/53724665 + self.messages: Union[asyncio.Queue[CancellationMessage], None] = None + + async def cancel(self, task_id: str) -> None: + if self.messages is None: + self.messages = asyncio.Queue() + + timestamp = time.time() + + await self.messages.put( + CancellationMessage(task_id=task_id, timestamp=timestamp) + ) + + async def _listen(self, started_listening: asyncio.Event) -> None: + if self.messages is None: + self.messages = asyncio.Queue() + + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(started_listening.set) + + while True: + message = await self.messages.get() + + for queue in self.queues: + await queue.put(message) diff --git a/src/taskiq_cancellation/notifiers/null.py b/src/taskiq_cancellation/notifiers/null.py new file mode 100644 index 0000000..bb19211 --- /dev/null +++ b/src/taskiq_cancellation/notifiers/null.py @@ -0,0 +1,20 @@ +import asyncio + +from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent + + +class NullCancellationNotifier(CancellationNotifier): + """ + \"Do nothing\" cancellation notifier + + May be useful if there's no need or ability to use an actual notifier + """ + + async def cancel(self, task_id: str) -> None: + pass + + async def listen_for_cancellation( + self, task_id: str, started_listening_event: StartedListeningEvent + ) -> None: + await started_listening_event.set() + await asyncio.sleep(float("+inf")) diff --git a/src/taskiq_cancellation/notifiers/queue.py b/src/taskiq_cancellation/notifiers/queue.py new file mode 100644 index 0000000..16ea174 --- /dev/null +++ b/src/taskiq_cancellation/notifiers/queue.py @@ -0,0 +1,72 @@ +import abc +import weakref +import asyncio +from typing import Union + +from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent +from taskiq_cancellation.exceptions import TaskCancellationException +from taskiq_cancellation.message import CancellationMessage + + +class QueueCancellationNotifier(CancellationNotifier): + """ + A helper cancellation notifier that uses one listener to receive cancellation messages and + notifies listeners from `listen_for_cancellation` via `asyncio.Queue` + + Requires :func:`_listen` to be implemeted + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + self.listener_task: Union[asyncio.Task, None] = None + self.queues: weakref.WeakSet[asyncio.Queue[CancellationMessage]] = ( + weakref.WeakSet() + ) + """Set of subscribers' `asyncio.Queue`s to populate when message's received""" + + async def shutdown(self) -> None: + if self.listener_task is not None: + self.listener_task.cancel() + await asyncio.wait([self.listener_task]) + + async def listen_for_cancellation( + self, task_id: str, started_listening_event: StartedListeningEvent + ) -> None: + cancellations: asyncio.Queue[CancellationMessage] = asyncio.Queue() + + if self.listener_task is None: + await self._create_listener_task() + + await self._subscribe(cancellations) + await started_listening_event.set() + + while True: + cancellation_message = await cancellations.get() + + if cancellation_message.task_id == task_id: + raise TaskCancellationException() + + @abc.abstractmethod + async def _listen(self, started_listening: asyncio.Event) -> None: + """ + Listens for cancellation messages and put them into subscribers' `asyncio.Queue`s + + :param started_listening: event to be set when listener is ready to receive messages + :type started_listening: asyncio.Event + """ + pass + + async def _create_listener_task(self): + if self.listener_task is not None: + self.listener_task.cancel() + + started_listening = asyncio.Event() + self.listener_task = asyncio.create_task(self._listen(started_listening)) + await started_listening.wait() + + async def _subscribe(self, queue: asyncio.Queue[CancellationMessage]): + self.queues.add(queue) + + async def _unsubsribe(self, queue: asyncio.Queue[CancellationMessage]): + self.queues.remove(queue) diff --git a/src/taskiq_cancellation/notifiers/redis.py b/src/taskiq_cancellation/notifiers/redis.py new file mode 100644 index 0000000..fe7746e --- /dev/null +++ b/src/taskiq_cancellation/notifiers/redis.py @@ -0,0 +1,70 @@ +import time +import asyncio + +import redis.asyncio as redis +from taskiq.compat import model_dump, model_validate + +from taskiq_cancellation.message import CancellationMessage + +from .queue import QueueCancellationNotifier + + +class PubSubCancellationNotifier(QueueCancellationNotifier): + """Cancellation notifier using Redis pub/sub""" + + CHANNEL_NAME = "__taskiq_cancellation_notifications" + + def __init__(self, url: str, **connection_kwargs) -> None: + """ + Creates AioPika notifier + + :param url: url to redis + :type url: str + :param connection_kwargs: arguments for :ref:`redis.BlockingConnectionPool.from_url` + """ + super().__init__() + + self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs) + + async def cancel(self, task_id: str) -> None: + timestamp = time.time() + + async with redis.Redis(connection_pool=self.connection_pool) as conn: + await conn.publish( + self.CHANNEL_NAME, + self.serializer.dumpb( + model_dump( + CancellationMessage(task_id=task_id, timestamp=timestamp) + ) + ), + ) + + async def _listen(self, started_listening: asyncio.Event): + async with redis.Redis(connection_pool=self.connection_pool) as conn: + pubsub = conn.pubsub() + await pubsub.subscribe(self.CHANNEL_NAME) + + # started_listening.set() + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(started_listening.set) + + while True: + message = await pubsub.get_message( + ignore_subscribe_messages=True, timeout=None + ) + + if message is None: + continue + if message["type"] != "message": + continue + + cancellation_message = model_validate( + CancellationMessage, self.serializer.loadb(message["data"]) + ) + for queue in self.queues: + await queue.put(cancellation_message) + + async def shutdown(self) -> None: + await super().shutdown() + await self.connection_pool.aclose() + \ No newline at end of file diff --git a/src/taskiq_cancellation/integrations/__init__.py b/src/taskiq_cancellation/py.typed similarity index 100% rename from src/taskiq_cancellation/integrations/__init__.py rename to src/taskiq_cancellation/py.typed diff --git a/src/taskiq_cancellation/state_holders/__init__.py b/src/taskiq_cancellation/state_holders/__init__.py new file mode 100644 index 0000000..9296f10 --- /dev/null +++ b/src/taskiq_cancellation/state_holders/__init__.py @@ -0,0 +1,8 @@ +from .in_memory import InMemoryCancellationStateHolder +from .null import NullCancellationStateHolder + + +__all__ = [ + "InMemoryCancellationStateHolder", + "NullCancellationStateHolder" +] diff --git a/src/taskiq_cancellation/state_holders/in_memory.py b/src/taskiq_cancellation/state_holders/in_memory.py new file mode 100644 index 0000000..b63277e --- /dev/null +++ b/src/taskiq_cancellation/state_holders/in_memory.py @@ -0,0 +1,16 @@ +from taskiq_cancellation.abc import CancellationStateHolder + + +class InMemoryCancellationStateHolder(CancellationStateHolder): + """In memory cancellation state holder used for testing""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + self.state_holder: dict[str, bool] = {} + + async def cancel(self, task_id: str) -> None: + self.state_holder[task_id] = True + + async def is_cancelled(self, task_id: str) -> bool: + return self.state_holder.get(task_id, False) diff --git a/src/taskiq_cancellation/state_holders/null.py b/src/taskiq_cancellation/state_holders/null.py new file mode 100644 index 0000000..b24462d --- /dev/null +++ b/src/taskiq_cancellation/state_holders/null.py @@ -0,0 +1,15 @@ +from taskiq_cancellation.abc import CancellationStateHolder + + +class NullCancellationStateHolder(CancellationStateHolder): + """ + \"Do nothing\" cancellation state holder + + May be useful if there's no need or ability to use an actual state holder + """ + + async def cancel(self, task_id: str) -> None: + pass + + async def is_cancelled(self, task_id: str) -> bool: + return False diff --git a/src/taskiq_cancellation/integrations/redis/state_holder.py b/src/taskiq_cancellation/state_holders/redis.py similarity index 76% rename from src/taskiq_cancellation/integrations/redis/state_holder.py rename to src/taskiq_cancellation/state_holders/redis.py index 4ca397c..be86f08 100644 --- a/src/taskiq_cancellation/integrations/redis/state_holder.py +++ b/src/taskiq_cancellation/state_holders/redis.py @@ -4,9 +4,9 @@ class RedisCancellationStateHolder(CancellationStateHolder): - def __init__(self, url: str, **connection_kwargs) -> None: - super().__init__() - self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs) + def __init__(self, url: str, **kwargs) -> None: + super().__init__(**kwargs) + self.connection_pool = redis.BlockingConnectionPool.from_url(url, **kwargs) async def cancel(self, task_id: str) -> None: async with redis.Redis(connection_pool=self.connection_pool) as conn: @@ -17,5 +17,10 @@ async def is_cancelled(self, task_id: str) -> bool: response = await conn.get(self._task_key(task_id)) return bool(response) + async def shutdown(self) -> None: + await super().shutdown() + await self.connection_pool.aclose() + def _task_key(self, task_id: str) -> str: return f"__cancellation_status_{task_id}" + diff --git a/src/taskiq_cancellation/utils.py b/src/taskiq_cancellation/utils.py index b1cda03..a128266 100644 --- a/src/taskiq_cancellation/utils.py +++ b/src/taskiq_cancellation/utils.py @@ -5,14 +5,16 @@ from collections import OrderedDict -def combines(wrapped): +def combines( + wrapped: typing.Callable, add_var_parameters: bool = False +) -> typing.Callable: """ Combines wrapped and wrapper functions signatures and type hints In cases of parameter collision wrapper parameters will be used Example: - ''' + ''' import inspect def decorator(func): @@ -20,41 +22,59 @@ def decorator(func): def wrapper(c: int, *args, **kwargs): return foo(*args, **kwargs) * c return wrapper - + @decorator def foo(a: int, b = "lol"): return b * a - + print(inspect.signature(foo)) # (a: int, c: int, b='lol', *args, **kwargs) ''' + + :param wrapped: function to be wrapped + :type wrapped: Callable + :param add_var_parameters: add *args and **kwargs from wrapper to new signature + :type add_var_parameters: bool """ wrapped_signature: inspect.Signature = inspect.signature(wrapped) wrapped_type_hints: typing.Dict[str, str] = typing.get_type_hints(wrapped) - + def decorator(wrapper): wrapper_signature = inspect.signature(wrapper) wrapper_type_hints = typing.get_type_hints(wrapper) for param_name in wrapped_signature.parameters.keys(): if param_name in wrapper_signature.parameters.keys(): - logging.warning(f"Parameter {param_name} will be overwritten by wrapper function") - - parameters = OrderedDict(wrapped_signature.parameters, **wrapper_signature.parameters) + logging.warning( + f"Parameter {param_name} will be overwritten by wrapper function" + ) + + wrapper_parameters = OrderedDict() + for name, parameter in wrapper_signature.parameters.items(): + if not add_var_parameters: + if any( + ( + parameter.kind is inspect.Parameter.VAR_POSITIONAL, + parameter.kind is inspect.Parameter.VAR_KEYWORD, + ) + ): + continue + wrapper_parameters[name] = parameter + + parameters = OrderedDict(wrapped_signature.parameters, **wrapper_parameters) parameters = sorted( parameters.values(), - key=lambda p: p.kind + (0.5 if p.default != inspect.Parameter.empty else 0) + key=lambda p: p.kind + (0.5 if p.default != inspect.Parameter.empty else 0), ) new_return_annotation: inspect.Signature if wrapper_signature.return_annotation is not None: - new_return_annotation = wrapper_signature.return_annotation + new_return_annotation = wrapper_signature.return_annotation else: - new_return_annotation = wrapped_signature.return_annotation + new_return_annotation = wrapped_signature.return_annotation new_signature = inspect.Signature( - parameters=parameters, - return_annotation=new_return_annotation + parameters=parameters, return_annotation=new_return_annotation ) new_annotations = dict(wrapped_type_hints, **wrapper_type_hints) @@ -63,9 +83,12 @@ def decorator(wrapper): wrapper.__signature__ = new_signature return wrapper + return decorator -__all__ = [ - "combines" -] +class StopTaskGroupException(Exception): + pass + + +__all__ = ["combines"] diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/aiopika/__init__.py b/tests/integration/aiopika/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/aiopika/conftest.py b/tests/integration/aiopika/conftest.py new file mode 100644 index 0000000..53e27dc --- /dev/null +++ b/tests/integration/aiopika/conftest.py @@ -0,0 +1,14 @@ +import os +import pytest + +from taskiq import InMemoryBroker + + +@pytest.fixture +def rabbitmq_url(): + return os.environ.get("TEST_RABBITMQ_URL", "amqp://guest:guest@localhost:5672") + + +@pytest.fixture +def broker(): + return InMemoryBroker() diff --git a/tests/integration/aiopika/test_notifier.py b/tests/integration/aiopika/test_notifier.py new file mode 100644 index 0000000..11e6553 --- /dev/null +++ b/tests/integration/aiopika/test_notifier.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.notifiers.aiopika import AioPikaNotifier + +from ..common.cancellations import run_notifier_cancellation_test + + +@pytest.fixture +def notifier(rabbitmq_url): + return AioPikaNotifier(url=rabbitmq_url) + + +@pytest.mark.asyncio +async def test_cancellation(notifier: AioPikaNotifier): + await run_notifier_cancellation_test(notifier) diff --git a/tests/integration/common/cancellations.py b/tests/integration/common/cancellations.py new file mode 100644 index 0000000..e495604 --- /dev/null +++ b/tests/integration/common/cancellations.py @@ -0,0 +1,94 @@ +import sys +import uuid +import pytest +import asyncio + +from taskiq import InMemoryBroker + +from taskiq_cancellation.abc import CancellationBackend, CancellationNotifier, CancellationStateHolder +from taskiq_cancellation.backends.modular import ModularCancellationBackend +from taskiq_cancellation.notifiers.null import NullCancellationNotifier +from taskiq_cancellation.state_holders.null import NullCancellationStateHolder +from taskiq_cancellation.exceptions import TaskCancellationException + +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout + + + +async def run_backend_cancellation_test(backend: CancellationBackend): + broker = InMemoryBroker() + backend = backend.with_broker(broker) + + @broker.task + @backend.cancellable + async def test_task(): + pass + + await broker.startup() + + task = await test_task.kiq() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +async def run_notifier_cancellation_test(notifier: CancellationNotifier): + broker = InMemoryBroker() + backend = ModularCancellationBackend( + NullCancellationStateHolder(), + notifier + ).with_broker(broker) + + task_started = asyncio.Event() + + @broker.task + @backend.cancellable + async def test_task(): + task_started.set() + await asyncio.sleep(0.2) + + await broker.startup() + + task = await test_task.kiq() + async with timeout(1): + await task_started.wait() + + await backend.cancel(task.task_id) + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +async def run_state_holder_cancellation_test(state_holder: CancellationStateHolder): + broker = InMemoryBroker() + backend = ModularCancellationBackend( + state_holder, + NullCancellationNotifier() + ).with_broker(broker) + + @broker.task + @backend.cancellable + async def test_task(): + # This task is supposed to never start + assert False + + await broker.startup() + + task_id = str(uuid.uuid4()) + await backend.cancel(task_id) + task = await test_task.kicker().with_task_id(task_id).kiq() + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() diff --git a/tests/integration/redis/__init__.py b/tests/integration/redis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/redis/conftest.py b/tests/integration/redis/conftest.py new file mode 100644 index 0000000..1823f42 --- /dev/null +++ b/tests/integration/redis/conftest.py @@ -0,0 +1,14 @@ +import os +import pytest + +from taskiq import InMemoryBroker + + +@pytest.fixture +def redis_url(): + return os.environ.get("TEST_REDIS_URL", "redis://localhost:6379") + + +@pytest.fixture +def broker(): + return InMemoryBroker() diff --git a/tests/integration/redis/test_backend.py b/tests/integration/redis/test_backend.py new file mode 100644 index 0000000..a5bde89 --- /dev/null +++ b/tests/integration/redis/test_backend.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.backends.redis import RedisCancellationBackend + +from ..common.cancellations import run_backend_cancellation_test + + +@pytest.fixture +def redis_backend(redis_url, broker): + return RedisCancellationBackend(url=redis_url).with_broker(broker) + + +@pytest.mark.asyncio +async def test_cancellation(redis_backend: RedisCancellationBackend): + await run_backend_cancellation_test(redis_backend) diff --git a/tests/integration/redis/test_pubsub.py b/tests/integration/redis/test_pubsub.py new file mode 100644 index 0000000..b49720c --- /dev/null +++ b/tests/integration/redis/test_pubsub.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.notifiers.redis import PubSubCancellationNotifier + +from ..common.cancellations import run_notifier_cancellation_test + + +@pytest.fixture +def pubsub_notifier(redis_url): + return PubSubCancellationNotifier(url=redis_url) + + +@pytest.mark.asyncio +async def test_cancellation(pubsub_notifier: PubSubCancellationNotifier): + await run_notifier_cancellation_test(pubsub_notifier) diff --git a/tests/integration/redis/test_state_holder.py b/tests/integration/redis/test_state_holder.py new file mode 100644 index 0000000..f8be37b --- /dev/null +++ b/tests/integration/redis/test_state_holder.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder + +from ..common.cancellations import run_state_holder_cancellation_test + + +@pytest.fixture +def state_holder(redis_url): + return RedisCancellationStateHolder(url=redis_url) + + +@pytest.mark.asyncio +async def test_cancellation(state_holder: RedisCancellationStateHolder): + await run_state_holder_cancellation_test(state_holder) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py new file mode 100644 index 0000000..459492d --- /dev/null +++ b/tests/unit/test_backend.py @@ -0,0 +1,50 @@ +import pytest +import inspect + +from taskiq_cancellation.abc.backend import CancellationBackend +from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend + + +@pytest.fixture +def backend(): + return InMemoryCancellationBackend() + + +@pytest.mark.asyncio +async def test_decorator_without_parentesis(backend: CancellationBackend): + """Tests that cancellable decorator works without parentesis""" + + @backend.cancellable + async def test_task(): + pass + + task = test_task() + assert inspect.iscoroutine(task) + await task + + +@pytest.mark.asyncio +async def test_decorator_with_parentesis(backend: CancellationBackend): + """Tests that cancellable decorator works with parentesis""" + + @backend.cancellable() + async def test_task(): + pass + + task = test_task() + assert inspect.iscoroutine(task) + await task + + +@pytest.mark.asyncio +async def test_decorator_with_sync_function(backend: CancellationBackend): + """ + Tests that cancellable decorator doesn't work with synchronous functions + + To launch a synchronous function we need to know how to do it and only Receiver knows that + """ + + with pytest.raises(ValueError): + @backend.cancellable + def test_task(): + pass diff --git a/tests/unit/test_cancellation.py b/tests/unit/test_cancellation.py new file mode 100644 index 0000000..892e988 --- /dev/null +++ b/tests/unit/test_cancellation.py @@ -0,0 +1,237 @@ +import sys +import pytest +import asyncio + +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout + +import anyio +from taskiq import AsyncBroker, InMemoryBroker + +from taskiq_cancellation.abc.backend import CancellationBackend, CancellationType +from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend +from taskiq_cancellation.exceptions import TaskCancellationException + + +@pytest.fixture +def broker(): + return InMemoryBroker() + + +@pytest.fixture +def backend(broker: AsyncBroker): + return InMemoryCancellationBackend().with_broker(broker) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("cancellation_type"), (CancellationType.LEVEL, CancellationType.EDGE) +) +async def test_task_direct_call( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, +): + """Tests that cancellable function can be called directly""" + + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + return True + + result = await test_task() + assert result + + +class TestTaskSuccess: + types = [CancellationType.LEVEL] + if sys.version_info >= (3, 11): + types.append(CancellationType.EDGE) + + @pytest.mark.asyncio + @pytest.mark.parametrize(("cancellation_type"), types) + @staticmethod + async def test_task_success( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, + ): + """Tests that cancellable task can successfully finish""" + + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + await asyncio.sleep(0.1) + + await broker.startup() + + task = await test_task.kiq() + result = await task.wait_result() + assert result.is_err is False + + await broker.shutdown() + + +class TestTaskCancellation: + types = [CancellationType.LEVEL] + if sys.version_info > (3, 11): + types.append(CancellationType.EDGE) + + @pytest.mark.asyncio + @pytest.mark.parametrize(("cancellation_type"), types) + @staticmethod + async def test_task_cancellation( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, + ): + """Tests that cancellable task can successfully cancel""" + + started_event = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + with pytest.raises(anyio.get_cancelled_exc_class()): + started_event.set() + await asyncio.sleep(0.2) + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + async with timeout(1): + await started_event.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +class TestEdgeCancellation: + @pytest.mark.asyncio + @pytest.mark.skipif( + sys.version_info >= (3, 11), + reason="Edge cancellation is currently not supported for Python <3.11", + ) + async def test_non_supported( + self, broker: AsyncBroker, backend: CancellationBackend + ): + """Tests that edge cancellation raises NotImplementedError in Python <3.11""" + + @broker.task + @backend.cancellable(cancellation_type=CancellationType.EDGE) + async def test_task(): + await asyncio.sleep(0.1) + + await broker.startup() + + task = await test_task.kiq() + + with pytest.raises(NotImplementedError): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + @pytest.mark.asyncio + @pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Edge cancellation is currently not supported for Python <3.11", + ) + async def test_cancellation_interception( + self, broker: AsyncBroker, backend: CancellationBackend + ): + """ + Tests that tasks can capture task cancellation + + asyncio raises asyncio.CancelledError only once. Task can intercept that to do cleanup, + but also can just ignore the cancellation request. + Docs: https://docs.python.org/3/library/asyncio-task.html#task-cancellation + """ + + cancelled_for_second_time = False + task_started = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=CancellationType.EDGE) + async def test_task(): + nonlocal cancelled_for_second_time + + try: + task_started.set() + await asyncio.sleep(0.5) + except asyncio.CancelledError: + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + cancelled_for_second_time = True + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + async with timeout(1): + await task_started.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + assert cancelled_for_second_time is False + + await broker.shutdown() + + +class TestLevelCancellation: + @pytest.mark.asyncio + async def test_repeated_cancellation( + self, broker: AsyncBroker, backend: CancellationBackend + ): + """ + Tests that task will have multiple cancellation exceptions + + anyio raises cancellation exception on every await + Docs: https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics + """ + + cancelled_for_second_time = False + started_event = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=CancellationType.LEVEL) + async def test_task(): + nonlocal cancelled_for_second_time + + try: + started_event.set() + await asyncio.sleep(1) + except anyio.get_cancelled_exc_class(): + # anyio cancels on any await after scope's cancellation + try: + await asyncio.sleep(0) + except anyio.get_cancelled_exc_class(): + cancelled_for_second_time = True + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + async with timeout(1): + await started_event.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + assert cancelled_for_second_time is True + + await broker.shutdown()