diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
new file mode 100644
index 0000000..cebff2b
--- /dev/null
+++ b/.github/workflows/lint.yaml
@@ -0,0 +1,33 @@
+name: Lint
+
+on:
+ pull_request:
+ branches: [develop, main]
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v5
+
+ - name: Setup Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: "3.14"
+
+ - name: Setup uv
+ uses: astral-sh/setup-uv@v7
+
+ - name: Create virtual environment
+ run: uv venv .venv && source .venv/bin/activate
+
+ - name: Install modules
+ run: uv sync
+
+ - name: Check code style
+ run: uv run ruff check
+
+ - name: Check static typing
+ run: uv run mypy .
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
new file mode 100644
index 0000000..9473ce2
--- /dev/null
+++ b/.github/workflows/release.yaml
@@ -0,0 +1,28 @@
+name: Release on PyPI
+
+on:
+ release:
+ types:
+ - released
+
+jobs:
+ publish:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v5
+
+ - name: Setup Python 3.14
+ uses: actions/setup-python@v6
+ with:
+ python-version: "3.14"
+
+ - name: Setup uv
+ uses: astral-sh/setup-uv@v7
+
+ - name: Build package
+ run: uv build
+
+ - name: Publish package
+ run: uv publish
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
new file mode 100644
index 0000000..ca80840
--- /dev/null
+++ b/.github/workflows/run_tests.yaml
@@ -0,0 +1,65 @@
+name: Testing
+
+on:
+ pull_request:
+ branches: [develop, main]
+
+jobs:
+ run-unit-tests:
+ strategy:
+ matrix:
+ python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
+ os: [ubuntu-latest, windows-latest, macos-latest]
+ fail-fast: false
+
+ runs-on: ${{ matrix.os }}
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v5
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v6
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Setup uv
+ uses: astral-sh/setup-uv@v7
+
+ # - name: Create virtual environment
+ # run: uv venv .venv && source .venv/bin/activate
+
+ - name: Install modules
+ run: uv sync
+
+ - name: Run unit tests for Python ${{ matrix.python-version }}
+ run: uv run pytest tests/unit
+
+ run-integration-tests:
+ runs-on: ubuntu-latest
+
+ strategy:
+ matrix:
+ python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
+ fail-fast: false
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v5
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v6
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Setup uv
+ uses: astral-sh/setup-uv@v7
+
+ - name: Install modules
+ run: uv sync --extra redis --extra aiopika
+
+ - name: Setup Docker containers
+ run: docker compose -f ./docker-compose-tests.yml up --wait
+
+ - name: Run integration tests for Python ${{ matrix.python-version }}
+ run: uv run pytest tests/integration
diff --git a/README.md b/README.md
index aa1a64f..57abe86 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,172 @@
-[](https://pypi.org/project/taskiq-cancellation)
-[](https://pypi.org/project/taskiq-cancellation)
+
+

+
-# Task cancellation for taskiq
+[](https://pypi.org/project/taskiq-cancellation)
+[](https://pypi.org/project/taskiq-cancellation)
+
+**[taskiq-cancellation](https://pypi.org/project/taskiq-cancellation)** aims to be a drop-in task cancellation solution for taskiq as the original package doesn't provide a cancellation API.
+
+## Contents:
+
+- [Installation](#installation)
+- [Usage](#usage)
+ - [What is a cancellation backend?](#what-is-a-cancellation-backend)
+ - [Modular cancellation backend](#modular-cancellation-backend)
+ - [Available integrations](#available-integrations)
+ - [Level and edge cancellation](#level-and-edge-cancellation)
+ - [Retry middlewares with task cancellation](#retry-middlewares-with-task-cancellation)
+- [Development](#development)
+- [Contributing](#contributing)
## Installation
-```console
+This package can be installed from PyPI with your package manager of choice.
+
+```bash
pip install taskiq-cancellation
+pipx install taskiq-cancellation
+poetry add taskiq-cancellation
+uv add taskiq-cancellation
+```
+
+taskiq-cancellation currently provides integrations with Redis and RabbitMQ that are installable with `redis` and `aiopika` extras respectively.
+
+```bash
+pip install "taskiq-cancellation[redis,aiopika]"
```
## Usage
+
+To do task cancellation, you need to:
+
+1. Create a cancellation backend
+2. Wrap a function with `cancellable` decorator
+3. Cancel the task with `cancel(task_id)`
+
+```python
+broker = PubSubBroker(url).with_result_backend(RedisAsyncResultBackend(url))
+cancellation_backend = RedisCancellationBackend(url).with_broker(broker)
+
+@broker.task
+@cancellation_backend.cancellable
+async def sleep(seconds: int):
+ await asyncio.sleep(seconds)
+ print("Slept!") # Won't be printed on worker side because of the cancellation
+
+async def main():
+ await broker.startup()
+
+ task = await sleep.kiq(5)
+ await cancellation_backend.cancel(task.task_id)
+
+ await broker.shutdown()
+
+asyncio.run(main())
+```
+
+### What is a cancellation backend?
+
+**Cancellation backend** can be seen as a combination of a broker and a result backend for cancellation messages that works underneath taskiq's broker. A cancellation backend won't run tasks marked as cancelled and will listen for cancellation messages for already running tasks.
+
+
+

+
+
+### Modular cancellation backend
+
+To easily create cancellation backends taskiq-cancellation provides `ModularCancellationBackend`. Modular cancellation backend consists of two parts: state holder and notifier.
+
+- State holder is used to check for task cancellation status before running the task.
+- Notifier is used to listen for cancellation messages while running the task
+
+This allows using any technology for task cancellation. For example, if one uses an SQL database and the RabbitMQ message broker, they can make a custom state holder with the SQL library of their choice and use the provided RabbitMQ notifier.
+
+```python
+from taskiq_cancellation import ModularCancellationBackend
+from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder
+from taskiq_cancellation.notifiers.aiopika import AioPikaCancellationNotifier
+
+backend = ModularCancellationBackend(
+ RedisCancellationStateHolder("redis://localhost:6379"),
+ AioPikaCancellationNotifier("amqp://guest:guest@localhost:5672")
+)
+```
+
+### Available integrations
+
+taskiq-cancellation provides:
+
+- state holder for Redis (`RedisCancellationStateHolder`)
+- notifiers for Redis pub/sub (`PubSubCancellationNotifier`) and RabbitMQ (`AioPikaCancellationNotifier`)
+
+Also there are `NullCancellationStateHolder` and `NullCancellationNotifier` that do absolutely nothing, for when there's no need to check for task cancellation before starting the task or no need to listen for cancellation of already running tasks.
+
+### Level and edge cancellation
+
+By default, taskiq-cancellation uses [`anyio`](https://anyio.readthedocs.io/en/stable/) and its [level cancellation](https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics). Level cancellation raises a cancellation exception on **every** asynchronous wait in a function.
+
+As external libraries might not support level cancellation, taskiq-cancellation also provides edge cancellation via `asyncio`. Edge cancellation raises an exception only _once_. To enable it, add `cancellation_type=CancellationType.EDGE` parameter to `cancellable` decorator.
+
+> [!WARNING]
+> Currently edge cancellation is supported only for Python 3.11+ because it uses [`asyncio.TaskGroup`](https://docs.python.org/3/library/asyncio-task.html#asyncio.TaskGroup)
+
+Example:
+
+```python
+from sqlalchemy.ext.asyncio import AsyncSession
+from taskiq_cancellation import CancellationType
+
+@broker.task
+@cancellation_backend.cancellable(cancellation_type=CancellationType.EDGE)
+async def sleep(seconds: int):
+ session = AsyncSession(engine)
+
+ try:
+ async with session.begin():
+ await asyncio.sleep(seconds)
+ session.add(SleptFor(seconds))
+ except asyncio.CancelledError:
+ # Won't raise cancelled exception
+ await session.close()
+ raise
+```
+
+### Retry middlewares with task cancellation
+
+If you use `SimpleRetryMiddleware` or `SmartRetryMiddleware`, make sure to add `TaskCancellationException` to `types_of_exceptions` parameter to not trigger additional retries.
+
+```python
+from taskiq_cancellation.exceptions import TaskCancellationException
+
+broker = (PubSubBroker(url)
+ .with_result_backend(RedisAsyncResultBackend(url))
+ .with_middlewares(
+ SimpleRetryMiddleware(
+ types_of_exceptions=[TaskCancellationException]
+ )
+ ))
+```
+
+## Development
+
+For linting, ruff is used
+
+```bash
+ruff check
+ruff format
+```
+
+For testing, pytest is used
+
+```bash
+pytest tests/unit # Unit tests
+
+# Integration tests
+docker compose -f docker-compose-tests.yml up --wait
+pytest tests/integration
+```
+
+## Contributing
+
+If you have any issues with this package or have an idea for improvement, please don't hesitate to open an issue! This is my first open-source project so I would like to ask to be a little patient with me though 🙏
diff --git a/docker-compose-tests.yml b/docker-compose-tests.yml
new file mode 100644
index 0000000..ffa0fbd
--- /dev/null
+++ b/docker-compose-tests.yml
@@ -0,0 +1,20 @@
+services:
+ redis:
+ image: redis:latest
+ ports:
+ - "6379:6379"
+
+ rabbitmq:
+ image: rabbitmq:latest
+ environment:
+ - RABBITMQ_DEFAULT_USER=guest
+ - RABBITMQ_DEFAULT_PASS=guest
+ hostname: localhost
+ ports:
+ - "5672:5672"
+ healthcheck:
+ test: "rabbitmq-diagnostics check_running -q"
+ interval: 5s
+ timeout: 5s
+ retries: 10
+ start_period: 5s
\ No newline at end of file
diff --git a/examples/counter/README.md b/examples/counter/README.md
new file mode 100644
index 0000000..f3bc5be
--- /dev/null
+++ b/examples/counter/README.md
@@ -0,0 +1,63 @@
+# Example: Counter
+
+This example demonstrates use of Redis cancellation backend with Redis broker.
+
+`count(up_to)` task counts up to `up_to`, one number per second. First task will finish fully after 5 seconds and second task will be cancelled after 2.5 seconds.
+
+## How to run
+
+1. Install dependencies
+
+```bash
+python -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt
+```
+
+2. Launch redis server locally
+
+3. Launch worker
+
+```bash
+taskiq worker main:broker --workers=1
+```
+
+4. Launch client
+
+```bash
+python main.py
+```
+
+## Expected result
+
+Client side:
+
+```console
+Sending task and waiting 5 seconds...
+Sending task and waiting 2.5 seconds...
+Canceling task...
+```
+
+Worker side:
+
+```console
+[2025-11-11 22:45:12,072][taskiq.receiver.receiver][INFO ][worker-0] Executing task main:count with ID: e2acdc76da94435d963b561b80098c47
+1 Mississippi
+2 Mississippi
+3 Mississippi
+4 Mississippi
+5 Mississippi
+[2025-11-11 22:45:17,073][taskiq.receiver.receiver][INFO ][worker-0] Executing task main:count with ID: e094778df6de450d811c4701cfff608f
+1 Mississippi
+2 Mississippi
+3 Mississippi
+[2025-11-11 22:45:19,589][taskiq.receiver.receiver][ERROR ][worker-0] Exception found while executing function:
+Traceback (most recent call last):
+ File "P:\Scratches\taskiq-cancellation\.venv\lib\site-packages\taskiq\receiver\receiver.py", line 254, in run_task
+ returned = await target_future
+ File "P:\Scratches\taskiq-cancellation\src\taskiq_cancellation\abc\backend.py", line 190, in wrapper
+ return await level_handler(*args, **kwargs)
+ File "P:\Scratches\taskiq-cancellation\src\taskiq_cancellation\cancellation_handlers\level.py", line 104, in __call__
+ raise TaskCancellationException()
+taskiq_cancellation.exceptions.TaskCancellationException
+```
diff --git a/examples/counter/main.py b/examples/counter/main.py
index 5facc63..61f9789 100644
--- a/examples/counter/main.py
+++ b/examples/counter/main.py
@@ -1,12 +1,12 @@
import asyncio
from taskiq_redis import PubSubBroker, RedisAsyncResultBackend
-from taskiq_cancellation.integrations.redis import RedisCancellationBackend
+from taskiq_cancellation.backends.redis import RedisCancellationBackend
url = "redis://localhost"
broker = PubSubBroker(url).with_result_backend(RedisAsyncResultBackend(url))
-cancellation_backend = RedisCancellationBackend(url)
+cancellation_backend = RedisCancellationBackend(url).with_broker(broker)
@broker.task
@@ -20,11 +20,14 @@ async def count(up_to: int):
async def main():
await broker.startup()
+ print("Sending task and waiting 5 seconds...")
task = await count.kiq(5)
await asyncio.sleep(5)
+ print("Sending task and waiting 2.5 seconds...")
task = await count.kiq(5)
await asyncio.sleep(2.5)
+ print("Canceling task...")
await cancellation_backend.cancel(task.task_id)
await broker.shutdown()
diff --git a/imgs/backend_scheme.png b/imgs/backend_scheme.png
new file mode 100644
index 0000000..30fb5cb
Binary files /dev/null and b/imgs/backend_scheme.png differ
diff --git a/imgs/header.png b/imgs/header.png
new file mode 100644
index 0000000..6956864
Binary files /dev/null and b/imgs/header.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 5c1c6cf..f55868d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,31 +1,32 @@
-[build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
-
[project]
name = "taskiq-cancellation"
dynamic = ["version"]
-description = 'Task cancellation mechanism for taskiq'
+description = 'Task cancellation for taskiq'
readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
license = "MIT"
keywords = ["taskiq", "cancellation"]
authors = [{ name = "Alexander Starikov", email = "acherryjam@gmail.com" }]
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
-dependencies = ["taskiq"]
+dependencies = [
+ "async-timeout>=5.0.1 ; python_full_version < '3.11'",
+ "taskiq",
+ "typing-extensions>=4.15.0 ; python_full_version < '3.11'",
+]
[project.optional-dependencies]
-redis = ["redis~=3.0"]
+redis = ["redis~=6.0"]
aiopika = ["aio_pika"]
[project.urls]
@@ -33,6 +34,10 @@ Documentation = "https://github.com/ACherryJam/taskiq-cancellation#readme"
Issues = "https://github.com/ACherryJam/taskiq-cancellation/issues"
Source = "https://github.com/ACherryJam/taskiq-cancellation"
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
[tool.hatch.version]
path = "src/taskiq_cancellation/__about__.py"
@@ -41,6 +46,21 @@ extra-dependencies = ["mypy>=1.0.0"]
[tool.hatch.envs.types.scripts]
check = "mypy --install-types --non-interactive {args:src/taskiq_cancellation tests}"
+[tool.hatch.build.targets.sdist]
+only-include = ["src/taskiq_cancellation"]
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/taskiq_cancellation"]
+
+[tool.mypy]
+ignore_missing_imports = true
+exclude = [
+ # Added so that "mypy ." would work
+ "examples",
+ # Contains Python 3.11+ code, has to be excluded. Runtime checks don't work.
+ "src/taskiq_cancellation/cancellation_handlers/edge_3_11.py",
+]
+
[tool.coverage.run]
source_pkgs = ["taskiq_cancellation", "tests"]
branch = true
@@ -56,3 +76,22 @@ tests = ["tests", "*/taskiq-cancellation/tests"]
[tool.coverage.report]
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
+
+[dependency-groups]
+dev = [
+ "mypy>=1.14.1",
+ "pytest>=8.3.5",
+ "pytest-asyncio>=0.24.0",
+ "ruff>=0.14.4",
+]
+
+[tool.ruff]
+# Edge cancellation is currently Python 3.11+ which causes linter to freak out
+# For such versioning cases tests must be written
+target-version = "py314"
+
+[[tool.uv.index]]
+name = "testpypi"
+url = "https://test.pypi.org/simple/"
+publish-url = "https://test.pypi.org/legacy/"
+explicit = true
diff --git a/src/taskiq_cancellation/__about__.py b/src/taskiq_cancellation/__about__.py
index d3ec452..f102a9c 100644
--- a/src/taskiq_cancellation/__about__.py
+++ b/src/taskiq_cancellation/__about__.py
@@ -1 +1 @@
-__version__ = "0.2.0"
+__version__ = "0.0.1"
diff --git a/src/taskiq_cancellation/__init__.py b/src/taskiq_cancellation/__init__.py
index 4ad9e80..914f0b2 100644
--- a/src/taskiq_cancellation/__init__.py
+++ b/src/taskiq_cancellation/__init__.py
@@ -1,8 +1,12 @@
-from .abc import *
-from .modular import ModularCancellationBackend
+from .abc import CancellationBackend
+from .backends.modular import ModularCancellationBackend
+from .backends.in_memory import InMemoryCancellationBackend
+from .cancellation_handlers.cancellation_type import CancellationType
+from .exceptions import TaskCancellationException
__all__ = [
- "ModularCancellationBackend"
+ "CancellationBackend", "ModularCancellationBackend",
+ "InMemoryCancellationBackend", "CancellationType", "TaskCancellationException"
]
diff --git a/src/taskiq_cancellation/abc/__init__.py b/src/taskiq_cancellation/abc/__init__.py
index cc8b52c..45239e2 100644
--- a/src/taskiq_cancellation/abc/__init__.py
+++ b/src/taskiq_cancellation/abc/__init__.py
@@ -1,10 +1,12 @@
from .backend import CancellationBackend
from .notifier import CancellationNotifier
from .state_holder import CancellationStateHolder
+from .started_listening_event import StartedListeningEvent
__all__ = [
"CancellationBackend",
"CancellationNotifier",
- "CancellationStateHolder"
+ "CancellationStateHolder",
+ "StartedListeningEvent",
]
diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py
index 4f5ff15..6f20b5e 100644
--- a/src/taskiq_cancellation/abc/backend.py
+++ b/src/taskiq_cancellation/abc/backend.py
@@ -1,98 +1,213 @@
import abc
-import asyncio
-import traceback
-from typing import Callable, Annotated, ParamSpec, TypeVar, Awaitable
+import sys
+import inspect
+from typing import Callable, Annotated, overload, Optional, cast, Union
-import anyio
-from anyio.abc import TaskStatus
-from taskiq import Context, TaskiqDepends
+if sys.version_info >= (3, 11):
+ from typing import Self, ParamSpec, TypeVar
+else:
+ from typing_extensions import Self, ParamSpec, TypeVar
-from ..utils import combines
-from ..exceptions import TaskCancellationException
+from taskiq import Context, TaskiqDepends, AsyncBroker, TaskiqEvents, TaskiqState
+from taskiq_cancellation.utils import combines
+from taskiq_cancellation.cancellation_handlers import (
+ CancellationType,
+ LevelCancellationHandler,
+ EdgeCancellationHandler,
+)
-P = ParamSpec("P")
-R = TypeVar("R")
+from .started_listening_event import StartedListeningEvent
+
+
+Params = ParamSpec("Params")
+Result = TypeVar("Result")
class CancellationBackend(abc.ABC):
- @abc.abstractmethod
+ """
+ Base class for cancellation backend
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.broker: Union[AsyncBroker, None] = None
+
+ @abc.abstractmethod
async def is_cancelled(self, task_id: str) -> bool:
+ """
+ Checks if a task with task id of *task_id* is set to be cancelled
+
+ :param task_id: task id to check
+ :type task_id: str
+ :returns: task cancellation state
+ :rtype: bool
+ """
pass
@abc.abstractmethod
async def cancel(self, task_id: str) -> None:
+ """
+ Cancels a task with task id of *task_id*
+
+ :param task_id: id of the task to cancel
+ :type task_id: str
+ """
pass
@abc.abstractmethod
async def listen_for_cancellation(
- self, task_id: str, started_listening_task_status: TaskStatus
+ self, task_id: str, started_listening_event: StartedListeningEvent
) -> None:
+ """
+ Listens for cancellation messages and raises :ref:`TaskCancellationException` when
+ receives :ref:`CancellationMessage` with same id as *task_id*.
+
+ This function is used in :func:`cancellable` decorator.
+ Call `started_listening_event.set()` when the listener is ready
+ to receive messages.
+
+ :param task_id: id of task that will be listened for
+ :type task_id: str
+ :param started_listening_event: "listener started listening" confirmation event
+ :type started_listening_event: StartedListeningEvent
+ """
pass
-
- def cancellable(self, task: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
- # Executor type depends on receiver configuration which we
- # can't access in any way
- if not asyncio.iscoroutinefunction(task):
- raise ValueError("Can't cancel synchronous function")
-
- @combines(task)
- async def wrapper(
- *args,
- __taskiq_context: Annotated[Context, TaskiqDepends()],
- **kwargs
- ):
- task_id = __taskiq_context.message.task_id
- result = None
-
- listener_exception: Exception | None = None
- task_exception: Exception | None = None
- cancelled_by_request: bool = False
-
- async with anyio.create_task_group() as group:
- async def listen_for_cancellation(
- task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
- ):
- nonlocal listener_exception, cancelled_by_request
-
- try:
- await self.listen_for_cancellation(task_id, task_status)
- except TaskCancellationException:
- cancelled_by_request = True
- except anyio.get_cancelled_exc_class():
- pass
- except Exception as e:
- listener_exception = e
- finally:
- group.cancel_scope.cancel()
-
- async def call_task():
- nonlocal result, task_exception
-
- try:
- result = await task(*args, **kwargs)
- except anyio.get_cancelled_exc_class():
- pass
- except Exception as e:
- task_exception = e
- finally:
- group.cancel_scope.cancel()
-
- # Listen before checking for cancellation in database so
- # the message won't get lost in non-persistent queues
- await group.start(listen_for_cancellation)
- if await self.is_cancelled(task_id):
- cancelled_by_request = True
- group.cancel_scope.cancel()
- else:
- group.start_soon(call_task)
-
- if task_exception is not None:
- raise task_exception
- elif cancelled_by_request:
- raise TaskCancellationException()
- elif listener_exception is not None:
- raise listener_exception
- else:
- return result
- return wrapper
+
+ async def startup(self) -> None:
+ """
+ Starts up cancellation backend
+
+ Triggered only if backend has a broker set. To set a broker use :func:`with_broker`.
+ """
+ pass
+
+ async def shutdown(self) -> None:
+ """Shuts down cancellation backend
+
+ Triggered only if backend has a broker set. To set a broker use :ref:`with_broker`.
+ """
+ pass
+
+ def with_broker(self, broker: AsyncBroker) -> Self:
+ """
+ Set a broker and return updated cancellation backend
+
+ Sets up startup and shutdown event handlers for backend's startup
+ and shutdown methods respectively
+
+ :param broker: new broker
+ :type broker: AsyncBroker
+ :returns: self
+ """
+ if self.broker is not None:
+ self.broker.event_handlers[TaskiqEvents.CLIENT_STARTUP].remove(
+ self._broker_startup_handler
+ )
+ self.broker.event_handlers[TaskiqEvents.WORKER_STARTUP].remove(
+ self._broker_startup_handler
+ )
+ self.broker.event_handlers[TaskiqEvents.CLIENT_SHUTDOWN].remove(
+ self._broker_shutdown_handler
+ )
+ self.broker.event_handlers[TaskiqEvents.WORKER_SHUTDOWN].remove(
+ self._broker_shutdown_handler
+ )
+
+ self.broker = broker
+ self.broker.add_event_handler(
+ TaskiqEvents.CLIENT_STARTUP, self._broker_startup_handler
+ )
+ self.broker.add_event_handler(
+ TaskiqEvents.WORKER_STARTUP, self._broker_startup_handler
+ )
+ self.broker.add_event_handler(
+ TaskiqEvents.CLIENT_SHUTDOWN, self._broker_shutdown_handler
+ )
+ self.broker.add_event_handler(
+ TaskiqEvents.WORKER_SHUTDOWN, self._broker_shutdown_handler
+ )
+
+ return self
+
+ @overload
+ def cancellable(
+ self, cancellation_type: Callable[Params, Result]
+ ) -> Callable[Params, Result]:
+ pass
+
+ @overload
+ def cancellable(
+ self, cancellation_type: Optional[CancellationType] = None
+ ) -> Callable[[Callable[Params, Result]], Callable[Params, Result]]:
+ pass
+
+ def cancellable(self, cancellation_type=None):
+ """
+ Decorator that makes function cancellable
+
+ This decorator makes a new function that creates two tasks in :ref:`anyio.TaskGroup`:
+ 1. Cancellation message listener (uses :ref:`listen_for_cancellation`)
+ 2. Wrapped function
+
+ - Returns function's result/exception if it finishes successfully/unsuccessfully
+ - Raises :ref:`TaskCancellationException` if listener task receives cancellation message
+ - If listener task raises an exception, task is cancelled and exception is propagated upwards
+
+ :param cancellation_type: type of cancellation used
+ :type cancellation_type: CancellationType
+ :returns: Cancellable task function
+ """
+
+ defaults = {"cancellation_type": CancellationType.LEVEL}
+
+ def make_decorator(cancellation_type: CancellationType):
+ def decorator(
+ task: Callable[Params, Result], /
+ ) -> Callable[Params, Result]:
+ # Executor type depends on receiver configuration which we can't access in any way
+ if not inspect.iscoroutinefunction(task):
+ raise ValueError("Can't cancel synchronous function")
+
+ @combines(task)
+ async def wrapper(
+ *args,
+ __taskiq_context: Annotated[Context, TaskiqDepends(Context)] = None, # type: ignore
+ **kwargs,
+ ) -> Result:
+ if __taskiq_context is None:
+ # Ran the function directly, without kiq
+ return await task(*args, **kwargs)
+
+ task_id = __taskiq_context.message.task_id
+
+ if cancellation_type is CancellationType.EDGE:
+ edge_handler = EdgeCancellationHandler(self, task, task_id)
+ return await edge_handler(*args, **kwargs)
+ elif cancellation_type is CancellationType.LEVEL:
+ level_handler = LevelCancellationHandler(self, task, task_id)
+ return await level_handler(*args, **kwargs)
+ else:
+ raise ValueError(
+ f"Unknown cancellation type: {cancellation_type!r}"
+ )
+
+ # Wrapper adds a key-word only param with default value
+ casted_wrapper = cast(Callable[Params, Result], wrapper)
+ return casted_wrapper
+
+ return decorator
+
+ if callable(cancellation_type):
+ return make_decorator(**defaults)(cancellation_type)
+ else:
+ return make_decorator(
+ cancellation_type=cancellation_type or defaults["cancellation_type"]
+ )
+
+ async def _broker_startup_handler(self, _: TaskiqState) -> None:
+ await self.startup()
+
+ async def _broker_shutdown_handler(self, _: TaskiqState) -> None:
+ await self.shutdown()
diff --git a/src/taskiq_cancellation/abc/notifier.py b/src/taskiq_cancellation/abc/notifier.py
index 0faaa91..60b1d14 100644
--- a/src/taskiq_cancellation/abc/notifier.py
+++ b/src/taskiq_cancellation/abc/notifier.py
@@ -1,20 +1,67 @@
+import sys
import abc
-from anyio.abc import TaskStatus
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
+
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.serializers import JSONSerializer
+from .started_listening_event import StartedListeningEvent
+
class CancellationNotifier(abc.ABC):
- def __init__(self, serializer: TaskiqSerializer = JSONSerializer()):
- self.serializer = serializer
-
+ """Receives cancellation messages and notifies listeners of these messages"""
+
+ def __init__(self):
+ self.serializer = JSONSerializer()
+
+ async def startup(self) -> None:
+ """Starts up cancellation notifier"""
+ pass
+
+ async def shutdown(self) -> None:
+ """Shuts down cancellation notifier"""
+ pass
+
@abc.abstractmethod
async def cancel(self, task_id: str) -> None:
+ """
+ Sends a cancellation message of a task with task id of *task_id*
+
+ :param task_id: id of the task to cancel
+ :type task_id: str
+ """
pass
@abc.abstractmethod
async def listen_for_cancellation(
- self, task_id: str, started_listening_task_status: TaskStatus
+ self, task_id: str, started_listening_event: StartedListeningEvent
) -> None:
+ """
+ Listens for cancellation messages and raises :ref:`TaskCancellationException` when
+ receives :ref:`CancellationMessage` with same id as *task_id*.
+
+ This function is used in :func:`cancellable` decorator of :ref:`ModularCancellationBackend`.
+ Call `started_listening_event.set()` when the listener is ready to receive messages.
+
+ :param task_id: id of task that will be listened for
+ :type task_id: str
+ :param started_listening_event: "listener started listening" confirmation event
+ :type started_listening_event: StartedListeningEvent
+ """
pass
+
+ def with_serializer(self, serializer: TaskiqSerializer) -> Self:
+ """
+ Sets a serializer to be used by the notifier
+
+ :param serializer: serializer for cancellation messages
+ :type serializer: TaskiqSerializer
+ :return: self
+ """
+ self.serializer = serializer
+
+ return self
diff --git a/src/taskiq_cancellation/abc/started_listening_event.py b/src/taskiq_cancellation/abc/started_listening_event.py
new file mode 100644
index 0000000..4e43acf
--- /dev/null
+++ b/src/taskiq_cancellation/abc/started_listening_event.py
@@ -0,0 +1,22 @@
+import abc
+
+
+class StartedListeningEvent(abc.ABC):
+ """
+ A confirmation event for listeners to mark that they started listening to messages. API is
+ similar to :ref:`asyncio.Event`.
+
+ This is needed for different cancellation types:
+ - Level cancellation uses :ref:`anyio.abc.TaskStatus`
+ - Edge cancellation uses :ref:`asyncio.Event`
+ """
+
+ @abc.abstractmethod
+ async def set(self):
+ """Sets the event"""
+ pass
+
+ @abc.abstractmethod
+ async def wait(self):
+ """Waits for the event to be set"""
+ pass
diff --git a/src/taskiq_cancellation/abc/state_holder.py b/src/taskiq_cancellation/abc/state_holder.py
index 45df346..bb2a707 100644
--- a/src/taskiq_cancellation/abc/state_holder.py
+++ b/src/taskiq_cancellation/abc/state_holder.py
@@ -2,10 +2,34 @@
class CancellationStateHolder(abc.ABC):
+ """Holds cancellation state of Taskiq tasks"""
+
@abc.abstractmethod
async def cancel(self, task_id: str) -> None:
+ """
+ Sets a state of task with task id of *task_id* to be cancelled
+
+ :param task_id: id of the task to cancel
+ :type task_id: str
+ """
pass
@abc.abstractmethod
async def is_cancelled(self, task_id: str) -> bool:
+ """
+ Checks if a task with task id of *task_id* is set to be cancelled
+
+ :param task_id: task id to check
+ :type task_id: str
+ :returns: task cancellation state
+ :rtype: bool
+ """
+ pass
+
+ async def startup(self) -> None:
+ """Starts up cancellation state holder"""
+ pass
+
+ async def shutdown(self) -> None:
+ """Shuts down cancellation state holder"""
pass
diff --git a/src/taskiq_cancellation/backends/__init__.py b/src/taskiq_cancellation/backends/__init__.py
new file mode 100644
index 0000000..5c65c46
--- /dev/null
+++ b/src/taskiq_cancellation/backends/__init__.py
@@ -0,0 +1,8 @@
+from .modular import ModularCancellationBackend
+from .in_memory import InMemoryCancellationBackend
+
+
+__all__ = [
+ "ModularCancellationBackend",
+ "InMemoryCancellationBackend"
+]
diff --git a/src/taskiq_cancellation/backends/in_memory.py b/src/taskiq_cancellation/backends/in_memory.py
new file mode 100644
index 0000000..795ac6e
--- /dev/null
+++ b/src/taskiq_cancellation/backends/in_memory.py
@@ -0,0 +1,18 @@
+from taskiq_cancellation.notifiers.in_memory import InMemoryCancellationNotifier
+from taskiq_cancellation.state_holders.in_memory import InMemoryCancellationStateHolder
+
+from .modular import ModularCancellationBackend
+
+
+class InMemoryCancellationBackend(ModularCancellationBackend):
+ """
+ Cancellation backend that stores state and notifications in memory
+
+ Useful for testing purposes
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(
+ state_holder=InMemoryCancellationStateHolder(**kwargs),
+ notifier=InMemoryCancellationNotifier(**kwargs),
+ )
diff --git a/src/taskiq_cancellation/backends/modular.py b/src/taskiq_cancellation/backends/modular.py
new file mode 100644
index 0000000..c65c00c
--- /dev/null
+++ b/src/taskiq_cancellation/backends/modular.py
@@ -0,0 +1,70 @@
+import sys
+
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
+
+from taskiq.abc.serializer import TaskiqSerializer
+
+from taskiq_cancellation.abc import (
+ CancellationBackend,
+ CancellationNotifier,
+ CancellationStateHolder,
+ StartedListeningEvent,
+)
+
+import anyio
+
+
+class ModularCancellationBackend(CancellationBackend):
+ """
+ Modular cancellation backend made up of :class:`CancellationStateHolder`
+ and :class:`CancellationNotifier`
+
+ - `CancellationStateHolder` stores cancellation state and blocks the task from being run.
+ - `CancellationNotifier` receives cancellation messages and cancels already running tasks.
+ """
+
+ def __init__(
+ self, state_holder: CancellationStateHolder, notifier: CancellationNotifier
+ ):
+ super().__init__()
+
+ self.notifier: CancellationNotifier = notifier
+ self.state_holder: CancellationStateHolder = state_holder
+
+ async def is_cancelled(self, task_id: str) -> bool:
+ return await self.state_holder.is_cancelled(task_id)
+
+ async def cancel(self, task_id: str):
+ async with anyio.create_task_group() as group:
+ group.start_soon(self.state_holder.cancel, task_id)
+ group.start_soon(self.notifier.cancel, task_id)
+
+ async def listen_for_cancellation(
+ self, task_id: str, started_listening_event: StartedListeningEvent
+ ):
+ await self.notifier.listen_for_cancellation(task_id, started_listening_event)
+
+ async def startup(self) -> None:
+ await super().startup()
+ await self.state_holder.startup()
+ await self.notifier.startup()
+
+ async def shutdown(self) -> None:
+ await super().shutdown()
+ await self.state_holder.shutdown()
+ await self.notifier.shutdown()
+
+ def with_serializer(self, serializer: TaskiqSerializer) -> Self:
+ """
+ Sets a serializer to be used by the notifier
+
+ :param serializer: serializer for cancellation messages
+ :type serializer: TaskiqSerializer
+ :return: self
+ """
+ self.notifier = self.notifier.with_serializer(serializer)
+
+ return self
diff --git a/src/taskiq_cancellation/backends/redis.py b/src/taskiq_cancellation/backends/redis.py
new file mode 100644
index 0000000..3870e32
--- /dev/null
+++ b/src/taskiq_cancellation/backends/redis.py
@@ -0,0 +1,12 @@
+from taskiq_cancellation.backends.modular import ModularCancellationBackend
+
+from taskiq_cancellation.notifiers.redis import PubSubCancellationNotifier
+from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder
+
+
+class RedisCancellationBackend(ModularCancellationBackend):
+ def __init__(self, url: str, **kwargs) -> None:
+ super().__init__(
+ RedisCancellationStateHolder(url, **kwargs),
+ PubSubCancellationNotifier(url, **kwargs),
+ )
diff --git a/src/taskiq_cancellation/cancellation_handlers/__init__.py b/src/taskiq_cancellation/cancellation_handlers/__init__.py
new file mode 100644
index 0000000..a228c8d
--- /dev/null
+++ b/src/taskiq_cancellation/cancellation_handlers/__init__.py
@@ -0,0 +1,12 @@
+import sys
+
+from .cancellation_type import CancellationType
+from .level import LevelCancellationHandler
+
+if sys.version_info >= (3, 11):
+ from .edge_3_11 import EdgeCancellationHandler
+else:
+ from .edge_non_supported import EdgeCancellationHandler
+
+
+__all__ = ["CancellationType", "LevelCancellationHandler", "EdgeCancellationHandler"]
diff --git a/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py
new file mode 100644
index 0000000..9714e67
--- /dev/null
+++ b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py
@@ -0,0 +1,8 @@
+import enum
+
+
+class CancellationType(str, enum.Enum):
+ """Type of cancellation used by the backend"""
+
+ EDGE = "edge"
+ LEVEL = "level"
diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py
new file mode 100644
index 0000000..4bf66b5
--- /dev/null
+++ b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py
@@ -0,0 +1,136 @@
+# FIXME: bunch of issues because of Python 3.11+ exclusivity
+#
+# Edge cancellation handler is using asyncio.TaskGroup which was introduced in 3.11 and
+# uses expect-star syntax that was also introduced in the same version
+# - mypy has to ignore this file because it can't finish static parsing
+# - ruff's python version has to be set at 3.11+ so it wouldn't complain
+#
+# I'm not sure how to mitigate these issues. Maybe this can be put in a separate module somehow
+# and then integrated? Maybe this can be rewritten to not use TaskGroup (probably easier to do)?
+
+import sys
+import logging
+import asyncio
+from collections.abc import Coroutine
+from typing import Callable, TYPE_CHECKING, Generic, Any
+
+if sys.version_info >= (3, 11):
+ from typing import ParamSpec, TypeVar
+else:
+ from typing_extensions import ParamSpec, TypeVar
+
+from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent
+from taskiq_cancellation.exceptions import TaskCancellationException
+from taskiq_cancellation.utils import StopTaskGroupException
+
+if TYPE_CHECKING:
+ from taskiq_cancellation.abc.backend import CancellationBackend
+
+
+Params = ParamSpec("Params")
+Result = TypeVar("Result")
+
+
+class EdgeCancellationHandler(Generic[Params, Result]):
+ """
+ Wrapper around a task function that handles cancellation
+
+ Uses edge cancellation provided by asyncio. That means :ref:`asyncio.CancelledError` is
+ raised only once for the task.
+ Docs: https://docs.python.org/3/library/asyncio-task.html#task-cancellation
+
+ Currently supported only in Python 3.11+ due to the use of :ref:`asyncio.TaskGroup`.
+ """
+
+ class ListeningEvent(StartedListeningEvent):
+ def __init__(self) -> None:
+ self.event = asyncio.Event()
+
+ async def set(self):
+ loop = asyncio.get_running_loop()
+ loop.call_soon_threadsafe(self.event.set)
+
+ async def wait(self):
+ await self.event.wait()
+
+ def __init__(
+ self,
+ backend: "CancellationBackend",
+ task: Callable[Params, Coroutine[Any, Any, Result]],
+ task_id: str,
+ ):
+ self.backend = backend
+ self.task = task
+ self.task_id = task_id
+
+ async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result:
+ result: Result = None # type: ignore
+
+ listener_exception: Exception | None = None
+ task_exception: Exception | None = None
+ cancelled_by_request: bool = False
+
+ async def listen_for_cancellation(event: StartedListeningEvent):
+ nonlocal listener_exception, cancelled_by_request
+
+ try:
+ await self.backend.listen_for_cancellation(self.task_id, event)
+ except TaskCancellationException:
+ cancelled_by_request = True
+ raise
+ except asyncio.CancelledError:
+ raise
+ except Exception as e:
+ listener_exception = e
+ raise
+
+ async def call_task():
+ nonlocal result, task_exception
+
+ try:
+ result = await self.task(*args, **kwargs)
+ except asyncio.CancelledError:
+ raise
+ except Exception as e:
+ task_exception = e
+ raise
+
+ try:
+ async with asyncio.TaskGroup() as tg:
+ # Listen before checking for cancellation in state holder
+ # so the message won't get lost in non-persistent queues
+ event = self.ListeningEvent()
+ tg.create_task(listen_for_cancellation(event))
+ await event.wait()
+
+ if await self.backend.is_cancelled(self.task_id):
+ cancelled_by_request = True
+ raise StopTaskGroupException()
+
+ task_task = asyncio.create_task(call_task())
+ await task_task
+ if not task_task.cancelled():
+ raise StopTaskGroupException()
+ except* StopTaskGroupException:
+ pass
+ except* Exception as exc_group:
+ uncaught_exceptions = list(
+ filter(
+ lambda e: e != task_exception and e != listener_exception,
+ exc_group.exceptions,
+ )
+ )
+
+ if uncaught_exceptions:
+ logging.log(logging.ERROR, "Uncaught exceptions in TaskGroup")
+ for e in uncaught_exceptions:
+ logging.exception(e)
+
+ if task_exception is not None:
+ raise task_exception
+ elif cancelled_by_request:
+ raise TaskCancellationException()
+ elif listener_exception is not None:
+ raise listener_exception
+ else:
+ return result
diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py
new file mode 100644
index 0000000..03d06fd
--- /dev/null
+++ b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py
@@ -0,0 +1,34 @@
+import sys
+from typing import Callable, TYPE_CHECKING, Coroutine, Generic, Any
+
+if sys.version_info >= (3, 11):
+ from typing import ParamSpec, TypeVar
+else:
+ from typing_extensions import ParamSpec, TypeVar
+
+if TYPE_CHECKING:
+ from taskiq_cancellation.abc.backend import CancellationBackend
+
+
+Params = ParamSpec("Params")
+Result = TypeVar("Result")
+
+
+class EdgeCancellationHandler(Generic[Params, Result]):
+ """
+ Wrapper around a task function that handles cancellation
+
+ Uses edge cancellation provided by asyncio. Currently supported only in
+ Python 3.11+ due to the use of :ref:`asyncio.TaskGroup`.
+ """
+
+ def __init__(
+ self,
+ backend: "CancellationBackend",
+ task: Callable[Params, Coroutine[Any, Any, Result]],
+ task_id: str,
+ ) -> None:
+ raise NotImplementedError("Edge cancellation is not supported for Python <3.11")
+
+ async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result:
+ raise NotImplementedError("Edge cancellation is not supported for Python <3.11")
diff --git a/src/taskiq_cancellation/cancellation_handlers/level.py b/src/taskiq_cancellation/cancellation_handlers/level.py
new file mode 100644
index 0000000..2cfe71a
--- /dev/null
+++ b/src/taskiq_cancellation/cancellation_handlers/level.py
@@ -0,0 +1,110 @@
+import sys
+from collections.abc import Coroutine
+from typing import Callable, TYPE_CHECKING, Generic, Union, cast, Any
+
+if sys.version_info >= (3, 11):
+ from typing import ParamSpec, TypeVar
+else:
+ from typing_extensions import ParamSpec, TypeVar
+
+import anyio
+from anyio.abc import TaskStatus
+
+from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent
+from taskiq_cancellation.exceptions import TaskCancellationException
+
+if TYPE_CHECKING:
+ from taskiq_cancellation.abc.backend import CancellationBackend
+
+
+Params = ParamSpec("Params")
+Result = TypeVar("Result")
+
+
+class LevelCancellationHandler(Generic[Params, Result]):
+ """
+ Wrapper around a task function that handles cancellation
+
+ Uses level cancellation provided by anyio. That means cancellation exception is raised
+ on every await in the coroutine.
+ Docs: https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics
+ """
+
+ class ListeningEvent(StartedListeningEvent):
+ def __init__(self, task_status: TaskStatus) -> None:
+ self.task_status = task_status
+
+ async def set(self):
+ self.task_status.started()
+
+ async def wait(self):
+ # Can ignore, won't execute further before task status is set
+ pass
+
+ def __init__(
+ self,
+ backend: "CancellationBackend",
+ task: Callable[Params, Coroutine[Any, Any, Result]],
+ task_id: str,
+ ):
+ self.backend = backend
+ self.task = task
+ self.task_id = task_id
+
+ async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result:
+ result: Union[Result, None] = None
+
+ listener_exception: Union[Exception, None] = None
+ task_exception: Union[Exception, None] = None
+ cancelled_by_request: bool = False
+
+ async with anyio.create_task_group() as group:
+
+ async def listen_for_cancellation(
+ task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
+ ):
+ nonlocal listener_exception, cancelled_by_request
+
+ event = self.ListeningEvent(task_status)
+ try:
+ await self.backend.listen_for_cancellation(self.task_id, event)
+ except TaskCancellationException:
+ cancelled_by_request = True
+ except anyio.get_cancelled_exc_class():
+ pass
+ except Exception as e:
+ listener_exception = e
+ finally:
+ group.cancel_scope.cancel()
+
+ async def call_task():
+ nonlocal result, task_exception
+
+ try:
+ result = await self.task(*args, **kwargs)
+ except anyio.get_cancelled_exc_class():
+ pass
+ except Exception as e:
+ task_exception = e
+ finally:
+ group.cancel_scope.cancel()
+
+ # Listen before checking for cancellation in state holder
+ # so the message won't get lost in non-persistent queues
+ await group.start(listen_for_cancellation)
+ if await self.backend.is_cancelled(self.task_id):
+ cancelled_by_request = True
+ group.cancel_scope.cancel()
+ else:
+ group.start_soon(call_task)
+
+ if task_exception is not None:
+ raise task_exception
+ elif cancelled_by_request:
+ raise TaskCancellationException()
+ elif listener_exception is not None:
+ raise listener_exception
+ else:
+ # If the task is finished, it is definitely not None
+ result = cast(Result, result)
+ return result
diff --git a/src/taskiq_cancellation/integrations/aiopika/__init__.py b/src/taskiq_cancellation/integrations/aiopika/__init__.py
deleted file mode 100644
index 9d0488c..0000000
--- a/src/taskiq_cancellation/integrations/aiopika/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .notifier import AioPikaNotifier
-
-
-__all__ = [
- "AioPikaNotifier"
-]
diff --git a/src/taskiq_cancellation/integrations/aiopika/notifier.py b/src/taskiq_cancellation/integrations/aiopika/notifier.py
deleted file mode 100644
index b79ad78..0000000
--- a/src/taskiq_cancellation/integrations/aiopika/notifier.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import time
-
-import aio_pika
-from anyio.abc import TaskStatus
-from taskiq.abc.serializer import TaskiqSerializer
-from taskiq.serializers import JSONSerializer
-from taskiq.compat import model_dump, model_validate
-
-from taskiq_cancellation.abc import CancellationNotifier
-from taskiq_cancellation.exceptions import TaskCancellationException
-from taskiq_cancellation.message import CancellationMessage
-
-
-class AioPikaNotifier(CancellationNotifier):
- EXCHANGE_NAME = "__taskiq_cancellation"
-
- def __init__(self, url: str, serializer: TaskiqSerializer = JSONSerializer()):
- super().__init__(serializer)
-
- self.url: str = url
-
- async def cancel(self, task_id: str) -> None:
- timestamp = time.time()
- connection = await aio_pika.connect_robust(
- self.url
- )
- channel = await connection.channel()
-
- exchange = await channel.declare_exchange(
- self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True
- )
-
- await exchange.publish(
- aio_pika.Message(body=self.serializer.dumpb(
- model_dump(CancellationMessage(task_id=task_id, timestamp=timestamp))
- )),
- routing_key=""
- )
-
- async def listen_for_cancellation(
- self, task_id: str, started_listening_task_status: TaskStatus
- ) -> None:
- connection = await aio_pika.connect_robust(
- self.url
- )
- channel = await connection.channel()
-
- exchange = await channel.declare_exchange(
- self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True
- )
- queue = await channel.declare_queue(exclusive=True)
- await queue.bind(exchange)
-
- started_listening_task_status.started()
-
- async with queue.iterator() as queue_iter:
- async for message in queue_iter:
- cancellation_message = model_validate(
- CancellationMessage,
- self.serializer.loadb(message.body)
- )
-
- if cancellation_message.task_id == task_id:
- raise TaskCancellationException()
- await message.ack()
diff --git a/src/taskiq_cancellation/integrations/redis/__init__.py b/src/taskiq_cancellation/integrations/redis/__init__.py
deleted file mode 100644
index fe38b68..0000000
--- a/src/taskiq_cancellation/integrations/redis/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from .notifier import PubSubCancellationNotifier
-from .state_holder import RedisCancellationStateHolder
-from .backend import RedisCancellationBackend
-
-
-__all__ = [
- "PubSubCancellationNotifier",
- "RedisCancellationStateHolder",
- "RedisCancellationBackend"
-]
diff --git a/src/taskiq_cancellation/integrations/redis/backend.py b/src/taskiq_cancellation/integrations/redis/backend.py
deleted file mode 100644
index ad038d1..0000000
--- a/src/taskiq_cancellation/integrations/redis/backend.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from typing import Type
-
-from taskiq_cancellation.modular import ModularCancellationBackend
-
-from .notifier import PubSubCancellationNotifier
-from .state_holder import RedisCancellationStateHolder
-
-
-class RedisCancellationBackend(ModularCancellationBackend):
- def __init__(
- self,
- url: str,
- state_holder: Type = RedisCancellationStateHolder,
- notifier: Type = PubSubCancellationNotifier,
- **connection_kwargs
- ) -> None:
- super().__init__(
- state_holder(url, **connection_kwargs),
- PubSubCancellationNotifier(url, **connection_kwargs)
- )
diff --git a/src/taskiq_cancellation/integrations/redis/notifier.py b/src/taskiq_cancellation/integrations/redis/notifier.py
deleted file mode 100644
index 5ed24fb..0000000
--- a/src/taskiq_cancellation/integrations/redis/notifier.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import time
-
-from anyio.abc import TaskStatus
-import redis.asyncio as redis
-from taskiq.compat import model_dump, model_validate
-
-from taskiq_cancellation.abc import CancellationNotifier
-from taskiq_cancellation.exceptions import TaskCancellationException
-from taskiq_cancellation.message import CancellationMessage
-
-
-class PubSubCancellationNotifier(CancellationNotifier):
- CHANNEL_NAME = "__taskiq_cancellation_notifications"
-
- def __init__(self, url: str, **connection_kwargs) -> None:
- super().__init__()
- self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs)
-
- async def cancel(self, task_id: str) -> None:
- timestamp = time.time()
-
- async with redis.Redis(connection_pool=self.connection_pool) as conn:
- await conn.publish(
- self.CHANNEL_NAME,
- self.serializer.dumpb(
- model_dump(CancellationMessage(task_id=task_id, timestamp=timestamp))
- )
- )
-
- async def listen_for_cancellation(
- self, task_id: str, started_listening_task_status: TaskStatus
- ) -> None:
- async with redis.Redis(connection_pool=self.connection_pool) as conn:
- pubsub = conn.pubsub()
- await pubsub.subscribe(self.CHANNEL_NAME)
-
- started_listening_task_status.started()
-
- while True:
- message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=None)
-
- if message is None:
- continue
- if message["type"] != "message":
- continue
-
- cancellation_message = model_validate(
- CancellationMessage,
- self.serializer.loadb(message['data'])
- )
- if cancellation_message.task_id == task_id:
- raise TaskCancellationException()
diff --git a/src/taskiq_cancellation/modular.py b/src/taskiq_cancellation/modular.py
deleted file mode 100644
index bead650..0000000
--- a/src/taskiq_cancellation/modular.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from .abc import *
-
-import anyio
-from anyio.abc import TaskStatus
-
-
-class ModularCancellationBackend(CancellationBackend):
- def __init__(
- self,
- state_holder: CancellationStateHolder,
- notifier: CancellationNotifier
- ):
- self.notifier: CancellationNotifier = notifier
- self.state_holder: CancellationStateHolder = state_holder
-
- async def is_cancelled(self, task_id: str) -> bool:
- return await self.state_holder.is_cancelled(task_id)
-
- async def cancel(self, task_id: str):
- async with anyio.create_task_group() as group:
- group.start_soon(self.state_holder.cancel, task_id)
- group.start_soon(self.notifier.cancel, task_id)
-
- async def listen_for_cancellation(
- self, task_id: str, started_listening_task_status: TaskStatus[None]
- ):
- await self.notifier.listen_for_cancellation(task_id, started_listening_task_status)
diff --git a/src/taskiq_cancellation/notifiers/__init__.py b/src/taskiq_cancellation/notifiers/__init__.py
new file mode 100644
index 0000000..9a097d6
--- /dev/null
+++ b/src/taskiq_cancellation/notifiers/__init__.py
@@ -0,0 +1,8 @@
+from .null import NullCancellationNotifier
+from .in_memory import InMemoryCancellationNotifier
+
+
+__all__ = [
+ "NullCancellationNotifier",
+ "InMemoryCancellationNotifier"
+]
diff --git a/src/taskiq_cancellation/notifiers/aiopika.py b/src/taskiq_cancellation/notifiers/aiopika.py
new file mode 100644
index 0000000..2432e56
--- /dev/null
+++ b/src/taskiq_cancellation/notifiers/aiopika.py
@@ -0,0 +1,76 @@
+import time
+import asyncio
+
+import aio_pika
+from taskiq.compat import model_dump, model_validate
+
+from taskiq_cancellation.message import CancellationMessage
+
+from .queue import QueueCancellationNotifier
+
+
+class AioPikaNotifier(QueueCancellationNotifier):
+ """Notifier for RabbitMQ using aio-pika"""
+
+ EXCHANGE_NAME = "__taskiq_cancellation"
+
+ def __init__(self, url: str, **connection_kwargs):
+ """
+ Creates AioPika notifier
+
+ :param url: url to rabbitmq
+ :type url: str
+ :param connection_kwargs: arguments for :ref:`aio_pika.connect_robust`
+ """
+ super().__init__()
+
+ self.url: str = url
+ self.connection_kwargs = connection_kwargs
+
+ async def cancel(self, task_id: str) -> None:
+ timestamp = time.time()
+
+ connection = await aio_pika.connect_robust(self.url, **self.connection_kwargs)
+
+ async with connection:
+ channel = await connection.channel()
+
+ exchange = await channel.declare_exchange(
+ self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True
+ )
+
+ await exchange.publish(
+ aio_pika.Message(
+ body=self.serializer.dumpb(
+ model_dump(
+ CancellationMessage(task_id=task_id, timestamp=timestamp)
+ )
+ )
+ ),
+ routing_key="",
+ )
+
+ async def _listen(self, started_listening: asyncio.Event):
+ connection = await aio_pika.connect_robust(self.url, **self.connection_kwargs)
+
+ async with connection:
+ channel = await connection.channel()
+
+ exchange = await channel.declare_exchange(
+ self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True
+ )
+ queue = await channel.declare_queue(exclusive=True, auto_delete=True)
+ await queue.bind(exchange)
+
+ loop = asyncio.get_running_loop()
+ loop.call_soon_threadsafe(started_listening.set)
+
+ async with queue.iterator() as queue_iter:
+ async for message in queue_iter:
+ cancellation_message = model_validate(
+ CancellationMessage, self.serializer.loadb(message.body)
+ )
+
+ for subscriber_queue in self.queues:
+ await subscriber_queue.put(cancellation_message)
+ await message.ack()
diff --git a/src/taskiq_cancellation/notifiers/in_memory.py b/src/taskiq_cancellation/notifiers/in_memory.py
new file mode 100644
index 0000000..1dfe59f
--- /dev/null
+++ b/src/taskiq_cancellation/notifiers/in_memory.py
@@ -0,0 +1,41 @@
+import time
+import asyncio
+from typing import Union
+
+from taskiq_cancellation.message import CancellationMessage
+
+from .queue import QueueCancellationNotifier
+
+
+class InMemoryCancellationNotifier(QueueCancellationNotifier):
+ """In memory cancellation notifier used for testing"""
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ # In Python 3.9 queues must be created inside a running loop
+ # Source: https://stackoverflow.com/questions/53724665
+ self.messages: Union[asyncio.Queue[CancellationMessage], None] = None
+
+ async def cancel(self, task_id: str) -> None:
+ if self.messages is None:
+ self.messages = asyncio.Queue()
+
+ timestamp = time.time()
+
+ await self.messages.put(
+ CancellationMessage(task_id=task_id, timestamp=timestamp)
+ )
+
+ async def _listen(self, started_listening: asyncio.Event) -> None:
+ if self.messages is None:
+ self.messages = asyncio.Queue()
+
+ loop = asyncio.get_running_loop()
+ loop.call_soon_threadsafe(started_listening.set)
+
+ while True:
+ message = await self.messages.get()
+
+ for queue in self.queues:
+ await queue.put(message)
diff --git a/src/taskiq_cancellation/notifiers/null.py b/src/taskiq_cancellation/notifiers/null.py
new file mode 100644
index 0000000..bb19211
--- /dev/null
+++ b/src/taskiq_cancellation/notifiers/null.py
@@ -0,0 +1,20 @@
+import asyncio
+
+from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent
+
+
+class NullCancellationNotifier(CancellationNotifier):
+ """
+ \"Do nothing\" cancellation notifier
+
+ May be useful if there's no need or ability to use an actual notifier
+ """
+
+ async def cancel(self, task_id: str) -> None:
+ pass
+
+ async def listen_for_cancellation(
+ self, task_id: str, started_listening_event: StartedListeningEvent
+ ) -> None:
+ await started_listening_event.set()
+ await asyncio.sleep(float("+inf"))
diff --git a/src/taskiq_cancellation/notifiers/queue.py b/src/taskiq_cancellation/notifiers/queue.py
new file mode 100644
index 0000000..16ea174
--- /dev/null
+++ b/src/taskiq_cancellation/notifiers/queue.py
@@ -0,0 +1,72 @@
+import abc
+import weakref
+import asyncio
+from typing import Union
+
+from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent
+from taskiq_cancellation.exceptions import TaskCancellationException
+from taskiq_cancellation.message import CancellationMessage
+
+
+class QueueCancellationNotifier(CancellationNotifier):
+ """
+ A helper cancellation notifier that uses one listener to receive cancellation messages and
+ notifies listeners from `listen_for_cancellation` via `asyncio.Queue`
+
+ Requires :func:`_listen` to be implemented
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.listener_task: Union[asyncio.Task, None] = None
+ self.queues: weakref.WeakSet[asyncio.Queue[CancellationMessage]] = (
+ weakref.WeakSet()
+ )
+ """Set of subscribers' `asyncio.Queue`s to populate when message's received"""
+
+ async def shutdown(self) -> None:
+ if self.listener_task is not None:
+ self.listener_task.cancel()
+ await asyncio.wait([self.listener_task])
+
+ async def listen_for_cancellation(
+ self, task_id: str, started_listening_event: StartedListeningEvent
+ ) -> None:
+ cancellations: asyncio.Queue[CancellationMessage] = asyncio.Queue()
+
+ if self.listener_task is None:
+ await self._create_listener_task()
+
+ await self._subscribe(cancellations)
+ await started_listening_event.set()
+
+ while True:
+ cancellation_message = await cancellations.get()
+
+ if cancellation_message.task_id == task_id:
+ raise TaskCancellationException()
+
+ @abc.abstractmethod
+ async def _listen(self, started_listening: asyncio.Event) -> None:
+ """
+ Listens for cancellation messages and puts them into subscribers' `asyncio.Queue`s
+
+ :param started_listening: event to be set when listener is ready to receive messages
+ :type started_listening: asyncio.Event
+ """
+ pass
+
+ async def _create_listener_task(self):
+ if self.listener_task is not None:
+ self.listener_task.cancel()
+
+ started_listening = asyncio.Event()
+ self.listener_task = asyncio.create_task(self._listen(started_listening))
+ await started_listening.wait()
+
+ async def _subscribe(self, queue: asyncio.Queue[CancellationMessage]):
+ self.queues.add(queue)
+
+ async def _unsubsribe(self, queue: asyncio.Queue[CancellationMessage]):
+ self.queues.remove(queue)
diff --git a/src/taskiq_cancellation/notifiers/redis.py b/src/taskiq_cancellation/notifiers/redis.py
new file mode 100644
index 0000000..fe7746e
--- /dev/null
+++ b/src/taskiq_cancellation/notifiers/redis.py
@@ -0,0 +1,70 @@
+import time
+import asyncio
+
+import redis.asyncio as redis
+from taskiq.compat import model_dump, model_validate
+
+from taskiq_cancellation.message import CancellationMessage
+
+from .queue import QueueCancellationNotifier
+
+
+class PubSubCancellationNotifier(QueueCancellationNotifier):
+ """Cancellation notifier using Redis pub/sub"""
+
+ CHANNEL_NAME = "__taskiq_cancellation_notifications"
+
+ def __init__(self, url: str, **connection_kwargs) -> None:
+ """
+ Creates Redis pub/sub cancellation notifier
+
+ :param url: url to redis
+ :type url: str
+ :param connection_kwargs: arguments for :ref:`redis.BlockingConnectionPool.from_url`
+ """
+ super().__init__()
+
+ self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs)
+
+ async def cancel(self, task_id: str) -> None:
+ timestamp = time.time()
+
+ async with redis.Redis(connection_pool=self.connection_pool) as conn:
+ await conn.publish(
+ self.CHANNEL_NAME,
+ self.serializer.dumpb(
+ model_dump(
+ CancellationMessage(task_id=task_id, timestamp=timestamp)
+ )
+ ),
+ )
+
+ async def _listen(self, started_listening: asyncio.Event):
+ async with redis.Redis(connection_pool=self.connection_pool) as conn:
+ pubsub = conn.pubsub()
+ await pubsub.subscribe(self.CHANNEL_NAME)
+
+ # Signal readiness from within the running loop so it is thread-safe
+ loop = asyncio.get_running_loop()
+ loop.call_soon_threadsafe(started_listening.set)
+
+ while True:
+ message = await pubsub.get_message(
+ ignore_subscribe_messages=True, timeout=None
+ )
+
+ if message is None:
+ continue
+ if message["type"] != "message":
+ continue
+
+ cancellation_message = model_validate(
+ CancellationMessage, self.serializer.loadb(message["data"])
+ )
+ for queue in self.queues:
+ await queue.put(cancellation_message)
+
+ async def shutdown(self) -> None:
+ await super().shutdown()
+ await self.connection_pool.aclose()
+
\ No newline at end of file
diff --git a/src/taskiq_cancellation/integrations/__init__.py b/src/taskiq_cancellation/py.typed
similarity index 100%
rename from src/taskiq_cancellation/integrations/__init__.py
rename to src/taskiq_cancellation/py.typed
diff --git a/src/taskiq_cancellation/state_holders/__init__.py b/src/taskiq_cancellation/state_holders/__init__.py
new file mode 100644
index 0000000..9296f10
--- /dev/null
+++ b/src/taskiq_cancellation/state_holders/__init__.py
@@ -0,0 +1,8 @@
+from .in_memory import InMemoryCancellationStateHolder
+from .null import NullCancellationStateHolder
+
+
+__all__ = [
+ "InMemoryCancellationStateHolder",
+ "NullCancellationStateHolder"
+]
diff --git a/src/taskiq_cancellation/state_holders/in_memory.py b/src/taskiq_cancellation/state_holders/in_memory.py
new file mode 100644
index 0000000..b63277e
--- /dev/null
+++ b/src/taskiq_cancellation/state_holders/in_memory.py
@@ -0,0 +1,16 @@
+from taskiq_cancellation.abc import CancellationStateHolder
+
+
+class InMemoryCancellationStateHolder(CancellationStateHolder):
+ """In memory cancellation state holder used for testing"""
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.state_holder: dict[str, bool] = {}
+
+ async def cancel(self, task_id: str) -> None:
+ self.state_holder[task_id] = True
+
+ async def is_cancelled(self, task_id: str) -> bool:
+ return self.state_holder.get(task_id, False)
diff --git a/src/taskiq_cancellation/state_holders/null.py b/src/taskiq_cancellation/state_holders/null.py
new file mode 100644
index 0000000..b24462d
--- /dev/null
+++ b/src/taskiq_cancellation/state_holders/null.py
@@ -0,0 +1,15 @@
+from taskiq_cancellation.abc import CancellationStateHolder
+
+
+class NullCancellationStateHolder(CancellationStateHolder):
+ """
+ \"Do nothing\" cancellation state holder
+
+ May be useful if there's no need or ability to use an actual state holder
+ """
+
+ async def cancel(self, task_id: str) -> None:
+ pass
+
+ async def is_cancelled(self, task_id: str) -> bool:
+ return False
diff --git a/src/taskiq_cancellation/integrations/redis/state_holder.py b/src/taskiq_cancellation/state_holders/redis.py
similarity index 76%
rename from src/taskiq_cancellation/integrations/redis/state_holder.py
rename to src/taskiq_cancellation/state_holders/redis.py
index 4ca397c..be86f08 100644
--- a/src/taskiq_cancellation/integrations/redis/state_holder.py
+++ b/src/taskiq_cancellation/state_holders/redis.py
@@ -4,9 +4,9 @@
class RedisCancellationStateHolder(CancellationStateHolder):
- def __init__(self, url: str, **connection_kwargs) -> None:
- super().__init__()
- self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs)
+ def __init__(self, url: str, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.connection_pool = redis.BlockingConnectionPool.from_url(url, **kwargs)
async def cancel(self, task_id: str) -> None:
async with redis.Redis(connection_pool=self.connection_pool) as conn:
@@ -17,5 +17,10 @@ async def is_cancelled(self, task_id: str) -> bool:
response = await conn.get(self._task_key(task_id))
return bool(response)
+ async def shutdown(self) -> None:
+ await super().shutdown()
+ await self.connection_pool.aclose()
+
def _task_key(self, task_id: str) -> str:
return f"__cancellation_status_{task_id}"
+
diff --git a/src/taskiq_cancellation/utils.py b/src/taskiq_cancellation/utils.py
index b1cda03..a128266 100644
--- a/src/taskiq_cancellation/utils.py
+++ b/src/taskiq_cancellation/utils.py
@@ -5,14 +5,16 @@
from collections import OrderedDict
-def combines(wrapped):
+def combines(
+ wrapped: typing.Callable, add_var_parameters: bool = False
+) -> typing.Callable:
"""
Combines wrapped and wrapper functions signatures and type hints
In cases of parameter collision wrapper parameters will be used
Example:
- '''
+ '''
import inspect
def decorator(func):
@@ -20,41 +22,59 @@ def decorator(func):
def wrapper(c: int, *args, **kwargs):
return foo(*args, **kwargs) * c
return wrapper
-
+
@decorator
def foo(a: int, b = "lol"):
return b * a
-
+
print(inspect.signature(foo))
# (a: int, c: int, b='lol', *args, **kwargs)
'''
+
+ :param wrapped: function to be wrapped
+ :type wrapped: Callable
+ :param add_var_parameters: add *args and **kwargs from wrapper to new signature
+ :type add_var_parameters: bool
"""
wrapped_signature: inspect.Signature = inspect.signature(wrapped)
wrapped_type_hints: typing.Dict[str, str] = typing.get_type_hints(wrapped)
-
+
def decorator(wrapper):
wrapper_signature = inspect.signature(wrapper)
wrapper_type_hints = typing.get_type_hints(wrapper)
for param_name in wrapped_signature.parameters.keys():
if param_name in wrapper_signature.parameters.keys():
- logging.warning(f"Parameter {param_name} will be overwritten by wrapper function")
-
- parameters = OrderedDict(wrapped_signature.parameters, **wrapper_signature.parameters)
+ logging.warning(
+ f"Parameter {param_name} will be overwritten by wrapper function"
+ )
+
+ wrapper_parameters = OrderedDict()
+ for name, parameter in wrapper_signature.parameters.items():
+ if not add_var_parameters:
+ if any(
+ (
+ parameter.kind is inspect.Parameter.VAR_POSITIONAL,
+ parameter.kind is inspect.Parameter.VAR_KEYWORD,
+ )
+ ):
+ continue
+ wrapper_parameters[name] = parameter
+
+ parameters = OrderedDict(wrapped_signature.parameters, **wrapper_parameters)
parameters = sorted(
parameters.values(),
- key=lambda p: p.kind + (0.5 if p.default != inspect.Parameter.empty else 0)
+ key=lambda p: p.kind + (0.5 if p.default != inspect.Parameter.empty else 0),
)
new_return_annotation: inspect.Signature
if wrapper_signature.return_annotation is not None:
- new_return_annotation = wrapper_signature.return_annotation
+ new_return_annotation = wrapper_signature.return_annotation
else:
- new_return_annotation = wrapped_signature.return_annotation
+ new_return_annotation = wrapped_signature.return_annotation
new_signature = inspect.Signature(
- parameters=parameters,
- return_annotation=new_return_annotation
+ parameters=parameters, return_annotation=new_return_annotation
)
new_annotations = dict(wrapped_type_hints, **wrapper_type_hints)
@@ -63,9 +83,12 @@ def decorator(wrapper):
wrapper.__signature__ = new_signature
return wrapper
+
return decorator
-__all__ = [
- "combines"
-]
+class StopTaskGroupException(Exception):
+ pass
+
+
+__all__ = ["combines"]
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/integration/aiopika/__init__.py b/tests/integration/aiopika/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/integration/aiopika/conftest.py b/tests/integration/aiopika/conftest.py
new file mode 100644
index 0000000..53e27dc
--- /dev/null
+++ b/tests/integration/aiopika/conftest.py
@@ -0,0 +1,14 @@
+import os
+import pytest
+
+from taskiq import InMemoryBroker
+
+
+@pytest.fixture
+def rabbitmq_url():
+ return os.environ.get("TEST_RABBITMQ_URL", "amqp://guest:guest@localhost:5672")
+
+
+@pytest.fixture
+def broker():
+ return InMemoryBroker()
diff --git a/tests/integration/aiopika/test_notifier.py b/tests/integration/aiopika/test_notifier.py
new file mode 100644
index 0000000..11e6553
--- /dev/null
+++ b/tests/integration/aiopika/test_notifier.py
@@ -0,0 +1,15 @@
+import pytest
+
+from taskiq_cancellation.notifiers.aiopika import AioPikaNotifier
+
+from ..common.cancellations import run_notifier_cancellation_test
+
+
+@pytest.fixture
+def notifier(rabbitmq_url):
+ return AioPikaNotifier(url=rabbitmq_url)
+
+
+@pytest.mark.asyncio
+async def test_cancellation(notifier: AioPikaNotifier):
+ await run_notifier_cancellation_test(notifier)
diff --git a/tests/integration/common/cancellations.py b/tests/integration/common/cancellations.py
new file mode 100644
index 0000000..e495604
--- /dev/null
+++ b/tests/integration/common/cancellations.py
@@ -0,0 +1,94 @@
+import sys
+import uuid
+import pytest
+import asyncio
+
+from taskiq import InMemoryBroker
+
+from taskiq_cancellation.abc import CancellationBackend, CancellationNotifier, CancellationStateHolder
+from taskiq_cancellation.backends.modular import ModularCancellationBackend
+from taskiq_cancellation.notifiers.null import NullCancellationNotifier
+from taskiq_cancellation.state_holders.null import NullCancellationStateHolder
+from taskiq_cancellation.exceptions import TaskCancellationException
+
+if sys.version_info >= (3, 11):
+ from asyncio import timeout
+else:
+ from async_timeout import timeout
+
+
+
+async def run_backend_cancellation_test(backend: CancellationBackend):
+ broker = InMemoryBroker()
+ backend = backend.with_broker(broker)
+
+ @broker.task
+ @backend.cancellable
+ async def test_task():
+ pass
+
+ await broker.startup()
+
+ task = await test_task.kiq()
+ await backend.cancel(task.task_id)
+
+ with pytest.raises(TaskCancellationException):
+ result = await task.wait_result()
+ result.raise_for_error()
+
+ await broker.shutdown()
+
+
+async def run_notifier_cancellation_test(notifier: CancellationNotifier):
+ broker = InMemoryBroker()
+ backend = ModularCancellationBackend(
+ NullCancellationStateHolder(),
+ notifier
+ ).with_broker(broker)
+
+ task_started = asyncio.Event()
+
+ @broker.task
+ @backend.cancellable
+ async def test_task():
+ task_started.set()
+ await asyncio.sleep(0.2)
+
+ await broker.startup()
+
+ task = await test_task.kiq()
+ async with timeout(1):
+ await task_started.wait()
+
+ await backend.cancel(task.task_id)
+ with pytest.raises(TaskCancellationException):
+ result = await task.wait_result()
+ result.raise_for_error()
+
+ await broker.shutdown()
+
+
+async def run_state_holder_cancellation_test(state_holder: CancellationStateHolder):
+ broker = InMemoryBroker()
+ backend = ModularCancellationBackend(
+ state_holder,
+ NullCancellationNotifier()
+ ).with_broker(broker)
+
+ @broker.task
+ @backend.cancellable
+ async def test_task():
+ # This task is supposed to never start
+ assert False
+
+ await broker.startup()
+
+ task_id = str(uuid.uuid4())
+ await backend.cancel(task_id)
+ task = await test_task.kicker().with_task_id(task_id).kiq()
+
+ with pytest.raises(TaskCancellationException):
+ result = await task.wait_result()
+ result.raise_for_error()
+
+ await broker.shutdown()
diff --git a/tests/integration/redis/__init__.py b/tests/integration/redis/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/integration/redis/conftest.py b/tests/integration/redis/conftest.py
new file mode 100644
index 0000000..1823f42
--- /dev/null
+++ b/tests/integration/redis/conftest.py
@@ -0,0 +1,14 @@
+import os
+import pytest
+
+from taskiq import InMemoryBroker
+
+
+@pytest.fixture
+def redis_url():
+ return os.environ.get("TEST_REDIS_URL", "redis://localhost:6379")
+
+
+@pytest.fixture
+def broker():
+ return InMemoryBroker()
diff --git a/tests/integration/redis/test_backend.py b/tests/integration/redis/test_backend.py
new file mode 100644
index 0000000..a5bde89
--- /dev/null
+++ b/tests/integration/redis/test_backend.py
@@ -0,0 +1,15 @@
+import pytest
+
+from taskiq_cancellation.backends.redis import RedisCancellationBackend
+
+from ..common.cancellations import run_backend_cancellation_test
+
+
+@pytest.fixture
+def redis_backend(redis_url, broker):
+ return RedisCancellationBackend(url=redis_url).with_broker(broker)
+
+
+@pytest.mark.asyncio
+async def test_cancellation(redis_backend: RedisCancellationBackend):
+ await run_backend_cancellation_test(redis_backend)
diff --git a/tests/integration/redis/test_pubsub.py b/tests/integration/redis/test_pubsub.py
new file mode 100644
index 0000000..b49720c
--- /dev/null
+++ b/tests/integration/redis/test_pubsub.py
@@ -0,0 +1,15 @@
+import pytest
+
+from taskiq_cancellation.notifiers.redis import PubSubCancellationNotifier
+
+from ..common.cancellations import run_notifier_cancellation_test
+
+
+@pytest.fixture
+def pubsub_notifier(redis_url):
+ return PubSubCancellationNotifier(url=redis_url)
+
+
+@pytest.mark.asyncio
+async def test_cancellation(pubsub_notifier: PubSubCancellationNotifier):
+ await run_notifier_cancellation_test(pubsub_notifier)
diff --git a/tests/integration/redis/test_state_holder.py b/tests/integration/redis/test_state_holder.py
new file mode 100644
index 0000000..f8be37b
--- /dev/null
+++ b/tests/integration/redis/test_state_holder.py
@@ -0,0 +1,15 @@
+import pytest
+
+from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder
+
+from ..common.cancellations import run_state_holder_cancellation_test
+
+
+@pytest.fixture
+def state_holder(redis_url):
+ return RedisCancellationStateHolder(url=redis_url)
+
+
+@pytest.mark.asyncio
+async def test_cancellation(state_holder: RedisCancellationStateHolder):
+ await run_state_holder_cancellation_test(state_holder)
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py
new file mode 100644
index 0000000..459492d
--- /dev/null
+++ b/tests/unit/test_backend.py
@@ -0,0 +1,50 @@
+import pytest
+import inspect
+
+from taskiq_cancellation.abc.backend import CancellationBackend
+from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend
+
+
+@pytest.fixture
+def backend():
+ return InMemoryCancellationBackend()
+
+
+@pytest.mark.asyncio
+async def test_decorator_without_parenthesis(backend: CancellationBackend):
+ """Tests that cancellable decorator works without parentheses"""
+
+ @backend.cancellable
+ async def test_task():
+ pass
+
+ task = test_task()
+ assert inspect.iscoroutine(task)
+ await task
+
+
+@pytest.mark.asyncio
+async def test_decorator_with_parenthesis(backend: CancellationBackend):
+ """Tests that cancellable decorator works with parentheses"""
+
+ @backend.cancellable()
+ async def test_task():
+ pass
+
+ task = test_task()
+ assert inspect.iscoroutine(task)
+ await task
+
+
+@pytest.mark.asyncio
+async def test_decorator_with_sync_function(backend: CancellationBackend):
+ """
+ Tests that cancellable decorator doesn't work with synchronous functions
+
+ To launch a synchronous function we need to know how to do it and only Receiver knows that
+ """
+
+ with pytest.raises(ValueError):
+ @backend.cancellable
+ def test_task():
+ pass
diff --git a/tests/unit/test_cancellation.py b/tests/unit/test_cancellation.py
new file mode 100644
index 0000000..892e988
--- /dev/null
+++ b/tests/unit/test_cancellation.py
@@ -0,0 +1,237 @@
+import sys
+import pytest
+import asyncio
+
+if sys.version_info >= (3, 11):
+ from asyncio import timeout
+else:
+ from async_timeout import timeout
+
+import anyio
+from taskiq import AsyncBroker, InMemoryBroker
+
+from taskiq_cancellation.abc.backend import CancellationBackend, CancellationType
+from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend
+from taskiq_cancellation.exceptions import TaskCancellationException
+
+
+@pytest.fixture
+def broker():
+ return InMemoryBroker()
+
+
+@pytest.fixture
+def backend(broker: AsyncBroker):
+ return InMemoryCancellationBackend().with_broker(broker)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("cancellation_type"), (CancellationType.LEVEL, CancellationType.EDGE)
+)
+async def test_task_direct_call(
+ broker: AsyncBroker,
+ backend: CancellationBackend,
+ cancellation_type: CancellationType,
+):
+ """Tests that cancellable function can be called directly"""
+
+ @broker.task
+ @backend.cancellable(cancellation_type=cancellation_type)
+ async def test_task():
+ return True
+
+ result = await test_task()
+ assert result
+
+
+class TestTaskSuccess:
+ types = [CancellationType.LEVEL]
+ if sys.version_info >= (3, 11):
+ types.append(CancellationType.EDGE)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(("cancellation_type"), types)
+ @staticmethod
+ async def test_task_success(
+ broker: AsyncBroker,
+ backend: CancellationBackend,
+ cancellation_type: CancellationType,
+ ):
+ """Tests that cancellable task can successfully finish"""
+
+ @broker.task
+ @backend.cancellable(cancellation_type=cancellation_type)
+ async def test_task():
+ await asyncio.sleep(0.1)
+
+ await broker.startup()
+
+ task = await test_task.kiq()
+ result = await task.wait_result()
+ assert result.is_err is False
+
+ await broker.shutdown()
+
+
+class TestTaskCancellation:
+ types = [CancellationType.LEVEL]
+ if sys.version_info >= (3, 11):
+ types.append(CancellationType.EDGE)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(("cancellation_type"), types)
+ @staticmethod
+ async def test_task_cancellation(
+ broker: AsyncBroker,
+ backend: CancellationBackend,
+ cancellation_type: CancellationType,
+ ):
+ """Tests that cancellable task can successfully cancel"""
+
+ started_event = asyncio.Event()
+
+ @broker.task
+ @backend.cancellable(cancellation_type=cancellation_type)
+ async def test_task():
+ with pytest.raises(anyio.get_cancelled_exc_class()):
+ started_event.set()
+ await asyncio.sleep(0.2)
+
+ await broker.startup()
+
+ task = await test_task.kiq()
+ assert await task.is_ready() is False
+
+ async with timeout(1):
+ await started_event.wait()
+ await backend.cancel(task.task_id)
+
+ with pytest.raises(TaskCancellationException):
+ result = await task.wait_result()
+ result.raise_for_error()
+
+ await broker.shutdown()
+
+
+class TestEdgeCancellation:
+ @pytest.mark.asyncio
+ @pytest.mark.skipif(
+ sys.version_info >= (3, 11),
+ reason="Edge cancellation is currently not supported for Python <3.11",
+ )
+ async def test_non_supported(
+ self, broker: AsyncBroker, backend: CancellationBackend
+ ):
+ """Tests that edge cancellation raises NotImplementedError in Python <3.11"""
+
+ @broker.task
+ @backend.cancellable(cancellation_type=CancellationType.EDGE)
+ async def test_task():
+ await asyncio.sleep(0.1)
+
+ await broker.startup()
+
+ task = await test_task.kiq()
+
+ with pytest.raises(NotImplementedError):
+ result = await task.wait_result()
+ result.raise_for_error()
+
+ await broker.shutdown()
+
+ @pytest.mark.asyncio
+ @pytest.mark.skipif(
+ sys.version_info < (3, 11),
+ reason="Edge cancellation is currently not supported for Python <3.11",
+ )
+ async def test_cancellation_interception(
+ self, broker: AsyncBroker, backend: CancellationBackend
+ ):
+ """
+ Tests that tasks can capture task cancellation
+
+ asyncio raises asyncio.CancelledError only once. A task can intercept it to perform cleanup,
+ but it can also simply ignore the cancellation request.
+ Docs: https://docs.python.org/3/library/asyncio-task.html#task-cancellation
+ """
+
+ cancelled_for_second_time = False
+ task_started = asyncio.Event()
+
+ @broker.task
+ @backend.cancellable(cancellation_type=CancellationType.EDGE)
+ async def test_task():
+ nonlocal cancelled_for_second_time
+
+ try:
+ task_started.set()
+ await asyncio.sleep(0.5)
+ except asyncio.CancelledError:
+ try:
+ await asyncio.sleep(0)
+ except asyncio.CancelledError:
+ cancelled_for_second_time = True
+
+ await broker.startup()
+
+ task = await test_task.kiq()
+ assert await task.is_ready() is False
+
+ async with timeout(1):
+ await task_started.wait()
+ await backend.cancel(task.task_id)
+
+ with pytest.raises(TaskCancellationException):
+ result = await task.wait_result()
+ result.raise_for_error()
+ assert cancelled_for_second_time is False
+
+ await broker.shutdown()
+
+
+class TestLevelCancellation:
+ @pytest.mark.asyncio
+ async def test_repeated_cancellation(
+ self, broker: AsyncBroker, backend: CancellationBackend
+ ):
+ """
+ Tests that task will have multiple cancellation exceptions
+
+ anyio raises cancellation exception on every await
+ Docs: https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics
+ """
+
+ cancelled_for_second_time = False
+ started_event = asyncio.Event()
+
+ @broker.task
+ @backend.cancellable(cancellation_type=CancellationType.LEVEL)
+ async def test_task():
+ nonlocal cancelled_for_second_time
+
+ try:
+ started_event.set()
+ await asyncio.sleep(1)
+ except anyio.get_cancelled_exc_class():
+ # anyio cancels on any await after scope's cancellation
+ try:
+ await asyncio.sleep(0)
+ except anyio.get_cancelled_exc_class():
+ cancelled_for_second_time = True
+
+ await broker.startup()
+
+ task = await test_task.kiq()
+ assert await task.is_ready() is False
+
+ async with timeout(1):
+ await started_event.wait()
+ await backend.cancel(task.task_id)
+
+ with pytest.raises(TaskCancellationException):
+ result = await task.wait_result()
+ result.raise_for_error()
+ assert cancelled_for_second_time is True
+
+ await broker.shutdown()