diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..0e31ea4 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,33 @@ +name: Lint + +on: + pull_request: + branches: [develop] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: 3.14 + + - name: Setup uv + uses: astral-sh/setup-uv@v7 + + - name: Create virtual environment + run: uv venv .venv && source .venv/bin/activate + + - name: Install modules + run: uv sync + + - name: Check code style + run: uv run ruff check + + - name: Check static typing + run: uv run mypy . diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml new file mode 100644 index 0000000..3a45a80 --- /dev/null +++ b/.github/workflows/run_tests.yaml @@ -0,0 +1,35 @@ +name: Testing + +on: + pull_request: + branches: [develop] + +jobs: + run-tests: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + fail-fast: false + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Setup uv + uses: astral-sh/setup-uv@v7 + + - name: Create virtual environment + run: uv venv .venv && source .venv/bin/activate + + - name: Install modules + run: uv sync + + - name: Run tests for Python ${{ matrix.python-version }} + run: uv run pytest diff --git a/examples/counter/main.py b/examples/counter/main.py index 5facc63..a148aae 100644 --- a/examples/counter/main.py +++ b/examples/counter/main.py @@ -6,7 +6,7 @@ url = "redis://localhost" broker = PubSubBroker(url).with_result_backend(RedisAsyncResultBackend(url)) -cancellation_backend = RedisCancellationBackend(url) +cancellation_backend = RedisCancellationBackend(url).with_broker(broker) @broker.task @@ -20,11 +20,14 @@ async def count(up_to: int): async def main(): await broker.startup() + print("Sending task and waiting 5 seconds...") task = await count.kiq(5) await asyncio.sleep(5) + print("Sending task and waiting 2.5 seconds...") task = await count.kiq(5) await asyncio.sleep(2.5) + print("Canceling task...") await cancellation_backend.cancel(task.task_id) await broker.shutdown() diff --git a/pyproject.toml b/pyproject.toml index 5c1c6cf..869d263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,28 +1,25 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - [project] name = "taskiq-cancellation" dynamic = ["version"] -description = 'Task cancellation mechanism for taskiq' +description = 'Task cancellation for taskiq' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "MIT" keywords = ["taskiq", "cancellation"] authors = [{ name = "Alexander Starikov", email = "acherryjam@gmail.com" }] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["taskiq"] +dependencies = ["taskiq", "typing-extensions>=4.13.2"] [project.optional-dependencies] redis = ["redis~=3.0"] @@ -33,6 +30,10 @@ Documentation = "https://github.com/ACherryJam/taskiq-cancellation#readme" Issues = "https://github.com/ACherryJam/taskiq-cancellation/issues" Source = "https://github.com/ACherryJam/taskiq-cancellation" +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + [tool.hatch.version] path = "src/taskiq_cancellation/__about__.py" @@ -41,6 +42,11 @@ extra-dependencies = ["mypy>=1.0.0"] [tool.hatch.envs.types.scripts] check = "mypy --install-types --non-interactive {args:src/taskiq_cancellation tests}" +[tool.mypy] +ignore_missing_imports = true +exclude = ["examples"] + + [tool.coverage.run] source_pkgs = ["taskiq_cancellation", "tests"] branch = true @@ -56,3 +62,11 @@ tests = ["tests", "*/taskiq-cancellation/tests"] [tool.coverage.report] exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[dependency-groups] +dev = [ + "mypy>=1.14.1", + "pytest>=8.3.5", + "pytest-asyncio>=0.24.0", + "ruff>=0.14.4", +] diff --git a/src/taskiq_cancellation/__init__.py b/src/taskiq_cancellation/__init__.py index 4ad9e80..57864e4 100644 --- a/src/taskiq_cancellation/__init__.py +++ b/src/taskiq_cancellation/__init__.py @@ -1,8 +1,8 @@ -from .abc import * -from .modular import ModularCancellationBackend +from .abc import CancellationBackend +from .backends.modular import ModularCancellationBackend __all__ = [ + "CancellationBackend", "ModularCancellationBackend" ] - diff --git a/src/taskiq_cancellation/abc/__init__.py b/src/taskiq_cancellation/abc/__init__.py index cc8b52c..1decfa6 100644 --- a/src/taskiq_cancellation/abc/__init__.py +++ b/src/taskiq_cancellation/abc/__init__.py @@ -3,8 +3,4 @@ from .state_holder import CancellationStateHolder -__all__ = [ - "CancellationBackend", - "CancellationNotifier", - "CancellationStateHolder" -] +__all__ = ["CancellationBackend", "CancellationNotifier", "CancellationStateHolder"] diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index 4f5ff15..c3befb9 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -1,14 +1,14 @@ import abc import asyncio -import traceback -from typing import Callable, Annotated, ParamSpec, TypeVar, Awaitable +from typing import Callable, Annotated, TypeVar, Awaitable +from typing_extensions import ParamSpec, Self import anyio from anyio.abc import TaskStatus -from taskiq import Context, TaskiqDepends +from taskiq import Context, TaskiqDepends, AsyncBroker, TaskiqEvents, TaskiqState -from ..utils import combines -from ..exceptions import TaskCancellationException +from taskiq_cancellation.utils import combines +from taskiq_cancellation.exceptions import TaskCancellationException P = ParamSpec("P") @@ -16,42 +16,145 @@ class CancellationBackend(abc.ABC): - @abc.abstractmethod + """ + Base class for cancellation backend + """ + def __init__(self) -> None: + super().__init__() + + self.broker: AsyncBroker | None = None + + @abc.abstractmethod async def is_cancelled(self, task_id: str) -> bool: + """ + Checks if a task with task id of *task_id* is set to be cancelled + + :param task_id: task id to check + :type task_id: str + :returns: task cancellation state + :rtype: bool + """ pass @abc.abstractmethod async def cancel(self, task_id: str) -> None: + """ + Cancels a task with task id of *task_id* + + :param task_id: id of the task to cancel + :type task_id: str + """ pass @abc.abstractmethod async def listen_for_cancellation( self, task_id: str, started_listening_task_status: TaskStatus ) -> None: + """ + Listens for cancellation messages and raises :ref:`TaskCancellationException` when + receives :ref:`CancellationMessage` with same id as *task_id*. + + This function is used in :func:`cancellable` decorator. + Call `started_listening_task_status.started()` when the listener is ready + to receive messages. + + :param task_id: id of task that will be listened for + :type task_id: str + :param started_listening_task_status: + :type started_listening_task_status: anyio.abc.TaskStatus + """ pass - + + async def startup(self) -> None: + """ + Starts up cancellation backend + + Triggered only if backend has a broker set. To set a broker use :func:`with_broker`. + """ + pass + + async def shutdown(self) -> None: + """Shuts down cancellation backend + + Triggered only if backend has a broker set. To set a broker use :ref:`with_broker`. + """ + pass + + def with_broker(self, broker: AsyncBroker) -> Self: + """ + Set a broker and return updated cancellation backend + + Sets up startup and shutdown event handlers for backend's startup + and shutdown methods respectfully + + :param broker: new broker + :type broker: AsyncBroker + :returns: self + """ + if self.broker is not None: + self.broker.event_handlers[TaskiqEvents.CLIENT_STARTUP].remove( + self._broker_startup_handler + ) + self.broker.event_handlers[TaskiqEvents.WORKER_STARTUP].remove( + self._broker_startup_handler + ) + self.broker.event_handlers[TaskiqEvents.CLIENT_SHUTDOWN].remove( + self._broker_shutdown_handler + ) + self.broker.event_handlers[TaskiqEvents.WORKER_SHUTDOWN].remove( + self._broker_shutdown_handler + ) + + self.broker = broker + self.broker.add_event_handler( + TaskiqEvents.CLIENT_STARTUP, self._broker_startup_handler + ) + self.broker.add_event_handler( + TaskiqEvents.WORKER_STARTUP, self._broker_startup_handler + ) + self.broker.add_event_handler( + TaskiqEvents.CLIENT_SHUTDOWN, self._broker_shutdown_handler + ) + self.broker.add_event_handler( + TaskiqEvents.WORKER_SHUTDOWN, self._broker_shutdown_handler + ) + + return self + def cancellable(self, task: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - # Executor type depends on receiver configuration which we - # can't access in any way + """ + Decorator that makes funcion cancellable + + This decorator makes a new function that creates two tasks in :ref:`anyio.TaskGroup`: + 1. Cancellation message listener (uses :ref:`listen_for_cancellation`) + 2. Wrapped function + + - Returns function's result/exception if it finishes successfully/unsuccessfully + - Raises :ref:`TaskCancellationException` if listener task receives cancellation message + - If listener task raises an exception, task is cancelled and exception is propogated upwards + + :param task: Task function to wrap + :returns: Cancellable task function + """ + # Executor type depends on receiver configuration which we can't accessed in any way if not asyncio.iscoroutinefunction(task): raise ValueError("Can't cancel synchronous function") - + @combines(task) async def wrapper( - *args, - __taskiq_context: Annotated[Context, TaskiqDepends()], - **kwargs + *args, __taskiq_context: Annotated[Context, TaskiqDepends()], **kwargs ): task_id = __taskiq_context.message.task_id result = None listener_exception: Exception | None = None task_exception: Exception | None = None - cancelled_by_request: bool = False + cancelled_by_request: bool = False async with anyio.create_task_group() as group: + async def listen_for_cancellation( - task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED + task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ): nonlocal listener_exception, cancelled_by_request @@ -65,7 +168,7 @@ async def listen_for_cancellation( listener_exception = e finally: group.cancel_scope.cancel() - + async def call_task(): nonlocal result, task_exception @@ -78,15 +181,15 @@ async def call_task(): finally: group.cancel_scope.cancel() - # Listen before checking for cancellation in database so - # the message won't get lost in non-persistent queues + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues await group.start(listen_for_cancellation) if await self.is_cancelled(task_id): cancelled_by_request = True group.cancel_scope.cancel() else: group.start_soon(call_task) - + if task_exception is not None: raise task_exception elif cancelled_by_request: @@ -95,4 +198,11 @@ async def call_task(): raise listener_exception else: return result + return wrapper + + async def _broker_startup_handler(self, _: TaskiqState) -> None: + await self.startup() + + async def _broker_shutdown_handler(self, _: TaskiqState) -> None: + await self.shutdown() diff --git a/src/taskiq_cancellation/abc/notifier.py b/src/taskiq_cancellation/abc/notifier.py index 0faaa91..3ba408f 100644 --- a/src/taskiq_cancellation/abc/notifier.py +++ b/src/taskiq_cancellation/abc/notifier.py @@ -6,15 +6,44 @@ class CancellationNotifier(abc.ABC): + """Receives cancellation messages and notifies listeners of these messages""" + def __init__(self, serializer: TaskiqSerializer = JSONSerializer()): self.serializer = serializer - + + async def startup(self) -> None: + """Starts up cancellation notifier""" + pass + + async def shutdown(self) -> None: + """Shuts down cancellation notifier""" + pass + @abc.abstractmethod async def cancel(self, task_id: str) -> None: + """ + Sends a cancellation message of a task with task id of *task_id* + + :param task_id: id of the task to cancel + :type task_id: str + """ pass @abc.abstractmethod async def listen_for_cancellation( self, task_id: str, started_listening_task_status: TaskStatus ) -> None: + """ + Listens for cancellation messages and raises :ref:`TaskCancellationException` when + receives :ref:`CancellationMessage` with same id as *task_id*. + + This function is used in :func:`cancellable` decorator of :ref:`ModularCancellationBackend`. + Call `started_listening_task_status.started()` when the listener is ready + to receive messages. + + :param task_id: id of task that will be listened for + :type task_id: str + :param started_listening_task_status: + :type started_listening_task_status: anyio.abc.TaskStatus + """ pass diff --git a/src/taskiq_cancellation/abc/state_holder.py b/src/taskiq_cancellation/abc/state_holder.py index 45df346..8416440 100644 --- a/src/taskiq_cancellation/abc/state_holder.py +++ b/src/taskiq_cancellation/abc/state_holder.py @@ -2,10 +2,34 @@ class CancellationStateHolder(abc.ABC): + """Holds cancellation state of Taskiq tasks""" + @abc.abstractmethod async def cancel(self, task_id: str) -> None: + """ + Sets a state of task with task id of *task_id* to be cancelled + + :param task_id: id of the task to cancel + :type task_id: str + """ pass @abc.abstractmethod async def is_cancelled(self, task_id: str) -> bool: + """ + Checks if a task with task id of *task_id* is set to be cancelled + + :param task_id: task id to check + :type task_id: str + :returns: task cancellation state + :rtype: bool + """ + pass + + async def startup(self) -> None: + """Starts up cancellation state holder""" + pass + + async def shutdown(self) -> None: + """Shuts down cancellation state holder""" pass diff --git a/src/taskiq_cancellation/integrations/__init__.py b/src/taskiq_cancellation/backends/__init__.py similarity index 100% rename from src/taskiq_cancellation/integrations/__init__.py rename to src/taskiq_cancellation/backends/__init__.py diff --git a/src/taskiq_cancellation/backends/in_memory.py b/src/taskiq_cancellation/backends/in_memory.py new file mode 100644 index 0000000..ee5e385 --- /dev/null +++ b/src/taskiq_cancellation/backends/in_memory.py @@ -0,0 +1,12 @@ +from taskiq_cancellation.notifiers.in_memory import InMemoryCancellationNotifier +from taskiq_cancellation.state_holders.in_memory import InMemoryCancellationStateHolder + +from .modular import ModularCancellationBackend + + +class InMemoryCancellationBackend(ModularCancellationBackend): + def __init__(self, **kwargs): + super().__init__( + state_holder=InMemoryCancellationStateHolder(**kwargs), + notifier=InMemoryCancellationNotifier(**kwargs) + ) diff --git a/src/taskiq_cancellation/backends/modular.py b/src/taskiq_cancellation/backends/modular.py new file mode 100644 index 0000000..2d87dfe --- /dev/null +++ b/src/taskiq_cancellation/backends/modular.py @@ -0,0 +1,46 @@ +from taskiq_cancellation.abc import CancellationBackend, CancellationNotifier, CancellationStateHolder + +import anyio +from anyio.abc import TaskStatus + + +class ModularCancellationBackend(CancellationBackend): + """ + Modular cancellation backend made up of :class:`CancellationStateHolder` + and :class:`CancellationNotifier` + + - `CancellationStateHolder` stores cancellation state and blocks the task from being run. + - `CancellationNotifier` receives cancellation messages and cancels already running tasks. + """ + def __init__( + self, state_holder: CancellationStateHolder, notifier: CancellationNotifier + ): + super().__init__() + + self.notifier: CancellationNotifier = notifier + self.state_holder: CancellationStateHolder = state_holder + + async def is_cancelled(self, task_id: str) -> bool: + return await self.state_holder.is_cancelled(task_id) + + async def cancel(self, task_id: str): + async with anyio.create_task_group() as group: + group.start_soon(self.state_holder.cancel, task_id) + group.start_soon(self.notifier.cancel, task_id) + + async def listen_for_cancellation( + self, task_id: str, started_listening_task_status: TaskStatus[None] + ): + await self.notifier.listen_for_cancellation( + task_id, started_listening_task_status + ) + + async def startup(self) -> None: + await super().startup() + await self.state_holder.startup() + await self.notifier.startup() + + async def shutdown(self) -> None: + await super().shutdown() + await self.state_holder.shutdown() + await self.notifier.shutdown() diff --git a/src/taskiq_cancellation/backends/redis.py b/src/taskiq_cancellation/backends/redis.py new file mode 100644 index 0000000..3870e32 --- /dev/null +++ b/src/taskiq_cancellation/backends/redis.py @@ -0,0 +1,12 @@ +from taskiq_cancellation.backends.modular import ModularCancellationBackend + +from taskiq_cancellation.notifiers.redis import PubSubCancellationNotifier +from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder + + +class RedisCancellationBackend(ModularCancellationBackend): + def __init__(self, url: str, **kwargs) -> None: + super().__init__( + RedisCancellationStateHolder(url, **kwargs), + PubSubCancellationNotifier(url, **kwargs), + ) diff --git a/src/taskiq_cancellation/integrations/aiopika/__init__.py b/src/taskiq_cancellation/integrations/aiopika/__init__.py deleted file mode 100644 index 9d0488c..0000000 --- a/src/taskiq_cancellation/integrations/aiopika/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .notifier import AioPikaNotifier - - -__all__ = [ - "AioPikaNotifier" -] diff --git a/src/taskiq_cancellation/integrations/aiopika/notifier.py b/src/taskiq_cancellation/integrations/aiopika/notifier.py deleted file mode 100644 index b79ad78..0000000 --- a/src/taskiq_cancellation/integrations/aiopika/notifier.py +++ /dev/null @@ -1,65 +0,0 @@ -import time - -import aio_pika -from anyio.abc import TaskStatus -from taskiq.abc.serializer import TaskiqSerializer -from taskiq.serializers import JSONSerializer -from taskiq.compat import model_dump, model_validate - -from taskiq_cancellation.abc import CancellationNotifier -from taskiq_cancellation.exceptions import TaskCancellationException -from taskiq_cancellation.message import CancellationMessage - - -class AioPikaNotifier(CancellationNotifier): - EXCHANGE_NAME = "__taskiq_cancellation" - - def __init__(self, url: str, serializer: TaskiqSerializer = JSONSerializer()): - super().__init__(serializer) - - self.url: str = url - - async def cancel(self, task_id: str) -> None: - timestamp = time.time() - connection = await aio_pika.connect_robust( - self.url - ) - channel = await connection.channel() - - exchange = await channel.declare_exchange( - self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True - ) - - await exchange.publish( - aio_pika.Message(body=self.serializer.dumpb( - model_dump(CancellationMessage(task_id=task_id, timestamp=timestamp)) - )), - routing_key="" - ) - - async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus - ) -> None: - connection = await aio_pika.connect_robust( - self.url - ) - channel = await connection.channel() - - exchange = await channel.declare_exchange( - self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True - ) - queue = await channel.declare_queue(exclusive=True) - await queue.bind(exchange) - - started_listening_task_status.started() - - async with queue.iterator() as queue_iter: - async for message in queue_iter: - cancellation_message = model_validate( - CancellationMessage, - self.serializer.loadb(message.body) - ) - - if cancellation_message.task_id == task_id: - raise TaskCancellationException() - await message.ack() diff --git a/src/taskiq_cancellation/integrations/redis/__init__.py b/src/taskiq_cancellation/integrations/redis/__init__.py deleted file mode 100644 index fe38b68..0000000 --- a/src/taskiq_cancellation/integrations/redis/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .notifier import PubSubCancellationNotifier -from .state_holder import RedisCancellationStateHolder -from .backend import RedisCancellationBackend - - -__all__ = [ - "PubSubCancellationNotifier", - "RedisCancellationStateHolder", - "RedisCancellationBackend" -] diff --git a/src/taskiq_cancellation/integrations/redis/backend.py b/src/taskiq_cancellation/integrations/redis/backend.py deleted file mode 100644 index ad038d1..0000000 --- a/src/taskiq_cancellation/integrations/redis/backend.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Type - -from taskiq_cancellation.modular import ModularCancellationBackend - -from .notifier import PubSubCancellationNotifier -from .state_holder import RedisCancellationStateHolder - - -class RedisCancellationBackend(ModularCancellationBackend): - def __init__( - self, - url: str, - state_holder: Type = RedisCancellationStateHolder, - notifier: Type = PubSubCancellationNotifier, - **connection_kwargs - ) -> None: - super().__init__( - state_holder(url, **connection_kwargs), - PubSubCancellationNotifier(url, **connection_kwargs) - ) diff --git a/src/taskiq_cancellation/modular.py b/src/taskiq_cancellation/modular.py deleted file mode 100644 index bead650..0000000 --- a/src/taskiq_cancellation/modular.py +++ /dev/null @@ -1,27 +0,0 @@ -from .abc import * - -import anyio -from anyio.abc import TaskStatus - - -class ModularCancellationBackend(CancellationBackend): - def __init__( - self, - state_holder: CancellationStateHolder, - notifier: CancellationNotifier - ): - self.notifier: CancellationNotifier = notifier - self.state_holder: CancellationStateHolder = state_holder - - async def is_cancelled(self, task_id: str) -> bool: - return await self.state_holder.is_cancelled(task_id) - - async def cancel(self, task_id: str): - async with anyio.create_task_group() as group: - group.start_soon(self.state_holder.cancel, task_id) - group.start_soon(self.notifier.cancel, task_id) - - async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus[None] - ): - await self.notifier.listen_for_cancellation(task_id, started_listening_task_status) diff --git a/src/taskiq_cancellation/notifiers/__init__.py b/src/taskiq_cancellation/notifiers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/taskiq_cancellation/notifiers/aiopika.py b/src/taskiq_cancellation/notifiers/aiopika.py new file mode 100644 index 0000000..8e55ab8 --- /dev/null +++ b/src/taskiq_cancellation/notifiers/aiopika.py @@ -0,0 +1,61 @@ +import time +import asyncio + +import aio_pika +from taskiq.compat import model_dump, model_validate + +from taskiq_cancellation.message import CancellationMessage + +from .queue import QueueCancellationNotifier + + +class AioPikaNotifier(QueueCancellationNotifier): + EXCHANGE_NAME = "__taskiq_cancellation" + + def __init__(self, url: str, **kwargs): + super().__init__(**kwargs) + + self.url: str = url + + async def cancel(self, task_id: str) -> None: + timestamp = time.time() + connection = await aio_pika.connect_robust(self.url) + channel = await connection.channel() + + exchange = await channel.declare_exchange( + self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True + ) + + await exchange.publish( + aio_pika.Message( + body=self.serializer.dumpb( + model_dump( + CancellationMessage(task_id=task_id, timestamp=timestamp) + ) + ) + ), + routing_key="", + ) + + async def _listen(self, started_listening: asyncio.Event): + connection = await aio_pika.connect_robust(self.url) + channel = await connection.channel() + + exchange = await channel.declare_exchange( + self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True + ) + queue = await channel.declare_queue(exclusive=True, auto_delete=True) + await queue.bind(exchange) + + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(started_listening.set) + + async with queue.iterator() as queue_iter: + async for message in queue_iter: + cancellation_message = model_validate( + CancellationMessage, self.serializer.loadb(message.body) + ) + + for queue in self.queues: + await queue.put(cancellation_message) + await message.ack() diff --git a/src/taskiq_cancellation/notifiers/in_memory.py b/src/taskiq_cancellation/notifiers/in_memory.py new file mode 100644 index 0000000..748a129 --- /dev/null +++ b/src/taskiq_cancellation/notifiers/in_memory.py @@ -0,0 +1,35 @@ +import time +import asyncio + +from taskiq_cancellation.message import CancellationMessage + +from .queue import QueueCancellationNotifier + + +class InMemoryCancellationNotifier(QueueCancellationNotifier): + """In memory cancellation notifier used for testing""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + self.messages: asyncio.Queue[CancellationMessage] = asyncio.Queue() + + async def cancel(self, task_id: str) -> None: + timestamp = time.time() + + await self.messages.put( + CancellationMessage( + task_id=task_id, + timestamp=timestamp + ) + ) + + async def _listen(self, started_listening: asyncio.Event) -> None: + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(started_listening.set) + + while True: + message = await self.messages.get() + + for queue in self.queues: + await queue.put(message) diff --git a/src/taskiq_cancellation/notifiers/null.py b/src/taskiq_cancellation/notifiers/null.py new file mode 100644 index 0000000..9740012 --- /dev/null +++ b/src/taskiq_cancellation/notifiers/null.py @@ -0,0 +1,19 @@ +import asyncio + +from anyio.abc import TaskStatus +from taskiq_cancellation.abc.notifier import CancellationNotifier + + +class NullCancellationNotifier(CancellationNotifier): + """ + \"Do nothing\" cancellation notifier + + May be useful if there's no need or ability to use an actual notifier + """ + + async def cancel(self, task_id: str) -> None: + pass + + async def listen_for_cancellation(self, task_id: str, started_listening_task_status: TaskStatus) -> None: + started_listening_task_status.started() + await asyncio.sleep(float("+inf")) diff --git a/src/taskiq_cancellation/notifiers/queue.py b/src/taskiq_cancellation/notifiers/queue.py new file mode 100644 index 0000000..3190501 --- /dev/null +++ b/src/taskiq_cancellation/notifiers/queue.py @@ -0,0 +1,71 @@ +import abc +import weakref +import asyncio + +from anyio.abc import TaskStatus + +from taskiq_cancellation.abc import CancellationNotifier +from taskiq_cancellation.exceptions import TaskCancellationException +from taskiq_cancellation.message import CancellationMessage + + +class QueueCancellationNotifier(CancellationNotifier): + """ + A helper cancellation notifier that uses one listener to receive cancellation messages and + notifies listeners from `listen_for_cancellation` via `asyncio.Queue` + + Requires :func:`_listen` to be implemeted + """ + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + self.listener_task: asyncio.Task | None = None + self.queues: weakref.WeakSet[asyncio.Queue[CancellationMessage]] = ( + weakref.WeakSet() + ) + """Set of subscribers' `asyncio.Queue`s to populate when message's received""" + + async def shutdown(self) -> None: + if self.listener_task is not None: + self.listener_task.cancel() + + async def listen_for_cancellation( + self, task_id: str, started_listening_task_status: TaskStatus + ) -> None: + cancellations: asyncio.Queue[CancellationMessage] = asyncio.Queue() + + if self.listener_task is None: + await self._create_listener_task() + + await self._subscribe(cancellations) + started_listening_task_status.started() + + while True: + cancellation_message = await cancellations.get() + + if cancellation_message.task_id == task_id: + raise TaskCancellationException() + + @abc.abstractmethod + async def _listen(self, started_listening: asyncio.Event) -> None: + """ + Listens for cancellation messages and put them into subscribers' `asyncio.Queue`s + + :param started_listening: event to be set when listener is ready to receive messages + :type started_listening: asyncio.Event + """ + pass + + async def _create_listener_task(self): + if self.listener_task is not None: + self.listener_task.cancel() + + started_listening = asyncio.Event() + self.listener_task = asyncio.create_task(self._listen(started_listening)) + await started_listening.wait() + + async def _subscribe(self, queue: asyncio.Queue[CancellationMessage]): + self.queues.add(queue) + + async def _unsubsribe(self, queue: asyncio.Queue[CancellationMessage]): + self.queues.remove(queue) diff --git a/src/taskiq_cancellation/integrations/redis/notifier.py b/src/taskiq_cancellation/notifiers/redis.py similarity index 50% rename from src/taskiq_cancellation/integrations/redis/notifier.py rename to src/taskiq_cancellation/notifiers/redis.py index 5ed24fb..2efa7c2 100644 --- a/src/taskiq_cancellation/integrations/redis/notifier.py +++ b/src/taskiq_cancellation/notifiers/redis.py @@ -1,43 +1,48 @@ import time +import asyncio -from anyio.abc import TaskStatus import redis.asyncio as redis from taskiq.compat import model_dump, model_validate -from taskiq_cancellation.abc import CancellationNotifier -from taskiq_cancellation.exceptions import TaskCancellationException from taskiq_cancellation.message import CancellationMessage +from .queue import QueueCancellationNotifier -class PubSubCancellationNotifier(CancellationNotifier): + +class PubSubCancellationNotifier(QueueCancellationNotifier): CHANNEL_NAME = "__taskiq_cancellation_notifications" - def __init__(self, url: str, **connection_kwargs) -> None: - super().__init__() - self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs) + def __init__(self, url: str, **kwargs) -> None: + super().__init__(**kwargs) + + self.connection_pool = redis.BlockingConnectionPool.from_url(url, **kwargs) async def cancel(self, task_id: str) -> None: timestamp = time.time() async with redis.Redis(connection_pool=self.connection_pool) as conn: await conn.publish( - self.CHANNEL_NAME, + self.CHANNEL_NAME, self.serializer.dumpb( - model_dump(CancellationMessage(task_id=task_id, timestamp=timestamp)) - ) + model_dump( + CancellationMessage(task_id=task_id, timestamp=timestamp) + ) + ), ) - async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus - ) -> None: + async def _listen(self, started_listening: asyncio.Event): async with redis.Redis(connection_pool=self.connection_pool) as conn: pubsub = conn.pubsub() await pubsub.subscribe(self.CHANNEL_NAME) - started_listening_task_status.started() + # started_listening.set() + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(started_listening.set) while True: - message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=None) + message = await pubsub.get_message( + ignore_subscribe_messages=True, timeout=None + ) if message is None: continue @@ -45,8 +50,7 @@ async def listen_for_cancellation( continue cancellation_message = model_validate( - CancellationMessage, - self.serializer.loadb(message['data']) + CancellationMessage, self.serializer.loadb(message["data"]) ) - if cancellation_message.task_id == task_id: - raise TaskCancellationException() + for queue in self.queues: + await queue.put(cancellation_message) diff --git a/src/taskiq_cancellation/state_holders/__init__.py b/src/taskiq_cancellation/state_holders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/taskiq_cancellation/state_holders/in_memory.py b/src/taskiq_cancellation/state_holders/in_memory.py new file mode 100644 index 0000000..23a7089 --- /dev/null +++ b/src/taskiq_cancellation/state_holders/in_memory.py @@ -0,0 +1,16 @@ +from taskiq_cancellation.abc import CancellationStateHolder + + +class InMemoryCancellationStateHolder(CancellationStateHolder): + """In memory cancellation state holder used for testing""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + self.state_holder: dict[str, bool] = {} + + async def cancel(self, task_id: str) -> None: + self.state_holder[task_id] = True + + async def is_cancelled(self, task_id: str) -> bool: + return self.state_holder.get(task_id, False) diff --git a/src/taskiq_cancellation/state_holders/null.py b/src/taskiq_cancellation/state_holders/null.py new file mode 100644 index 0000000..4383071 --- /dev/null +++ b/src/taskiq_cancellation/state_holders/null.py @@ -0,0 +1,15 @@ +from taskiq_cancellation.abc import CancellationStateHolder + + +class NullCancellationStateHolder(CancellationStateHolder): + """ + \"Do nothing\" cancellation state holder + + May be useful if there's no need or ability to use an actual state holder + """ + + async def cancel(self, task_id: str) -> None: + pass + + async def is_cancelled(self, task_id: str) -> bool: + return False diff --git a/src/taskiq_cancellation/integrations/redis/state_holder.py b/src/taskiq_cancellation/state_holders/redis.py similarity index 85% rename from src/taskiq_cancellation/integrations/redis/state_holder.py rename to src/taskiq_cancellation/state_holders/redis.py index 4ca397c..262f121 100644 --- a/src/taskiq_cancellation/integrations/redis/state_holder.py +++ b/src/taskiq_cancellation/state_holders/redis.py @@ -4,9 +4,9 @@ class RedisCancellationStateHolder(CancellationStateHolder): - def __init__(self, url: str, **connection_kwargs) -> None: - super().__init__() - self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs) + def __init__(self, url: str, **kwargs) -> None: + super().__init__(**kwargs) + self.connection_pool = redis.BlockingConnectionPool.from_url(url, **kwargs) async def cancel(self, task_id: str) -> None: async with redis.Redis(connection_pool=self.connection_pool) as conn: diff --git a/src/taskiq_cancellation/utils.py b/src/taskiq_cancellation/utils.py index b1cda03..4d659f0 100644 --- a/src/taskiq_cancellation/utils.py +++ b/src/taskiq_cancellation/utils.py @@ -12,7 +12,7 @@ def combines(wrapped): In cases of parameter collision wrapper parameters will be used Example: - ''' + ''' import inspect def decorator(func): @@ -20,41 +20,44 @@ def decorator(func): def wrapper(c: int, *args, **kwargs): return foo(*args, **kwargs) * c return wrapper - + @decorator def foo(a: int, b = "lol"): return b * a - + print(inspect.signature(foo)) # (a: int, c: int, b='lol', *args, **kwargs) ''' """ wrapped_signature: inspect.Signature = inspect.signature(wrapped) wrapped_type_hints: typing.Dict[str, str] = typing.get_type_hints(wrapped) - + def decorator(wrapper): wrapper_signature = inspect.signature(wrapper) wrapper_type_hints = typing.get_type_hints(wrapper) for param_name in wrapped_signature.parameters.keys(): if param_name in wrapper_signature.parameters.keys(): - logging.warning(f"Parameter {param_name} will be overwritten by wrapper function") - - parameters = OrderedDict(wrapped_signature.parameters, **wrapper_signature.parameters) + logging.warning( + f"Parameter {param_name} will be overwritten by wrapper function" + ) + + parameters = OrderedDict( + wrapped_signature.parameters, **wrapper_signature.parameters + ) parameters = sorted( parameters.values(), - key=lambda p: p.kind + (0.5 if p.default != inspect.Parameter.empty else 0) + key=lambda p: p.kind + (0.5 if p.default != inspect.Parameter.empty else 0), ) new_return_annotation: inspect.Signature if wrapper_signature.return_annotation is not None: - new_return_annotation = wrapper_signature.return_annotation + new_return_annotation = wrapper_signature.return_annotation else: - new_return_annotation = wrapped_signature.return_annotation + new_return_annotation = wrapped_signature.return_annotation new_signature = inspect.Signature( - parameters=parameters, - return_annotation=new_return_annotation + parameters=parameters, return_annotation=new_return_annotation ) new_annotations = dict(wrapped_type_hints, **wrapper_type_hints) @@ -63,9 +66,8 @@ def decorator(wrapper): wrapper.__signature__ = new_signature return wrapper + return decorator -__all__ = [ - "combines" -] +__all__ = ["combines"] diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py new file mode 100644 index 0000000..a790d1e --- /dev/null +++ b/tests/test_cancellation.py @@ -0,0 +1,58 @@ +import pytest +import asyncio + +from taskiq import AsyncBroker, InMemoryBroker + +from taskiq_cancellation.abc import CancellationBackend +from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend +from taskiq_cancellation.exceptions import TaskCancellationException + + +@pytest.fixture +def broker(): + return InMemoryBroker() + + +@pytest.fixture +def backend(broker): + return InMemoryCancellationBackend().with_broker(broker) + + +@pytest.mark.asyncio +async def test_task_success(broker: AsyncBroker, backend: CancellationBackend): + """Test that cancellable task can run successfully""" + + @broker.task + @backend.cancellable + async def task(): + await asyncio.sleep(0.1) + + await broker.startup() + + t = await task.kiq() + + result = await t.wait_result() + assert result.is_err is False + + await broker.shutdown() + + +@pytest.mark.asyncio +async def test_task_cancellation(broker: AsyncBroker, backend: CancellationBackend): + """Test that cancellable task can be cancelled""" + + @broker.task + @backend.cancellable + async def task(): + await asyncio.sleep(0.3) + + await broker.startup() + + t = await task.kiq() + await backend.cancel(t.task_id) + + with pytest.raises(TaskCancellationException): + result = await t.wait_result() + result.raise_for_error() + + await broker.shutdown()