From 2a65c42d3bb8b35216ddf0ded629640b734d5d3a Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sat, 8 Nov 2025 23:53:12 +0300 Subject: [PATCH 01/18] feat: support level task cancellation (asyncio) --- src/taskiq_cancellation/abc/__init__.py | 3 +- src/taskiq_cancellation/abc/backend.py | 252 +++++++++++++----- src/taskiq_cancellation/abc/notifier.py | 5 +- .../abc/started_listening_event.py | 11 + src/taskiq_cancellation/notifiers/null.py | 11 +- src/taskiq_cancellation/notifiers/queue.py | 8 +- src/taskiq_cancellation/utils.py | 23 +- tests/test_cancellation.py | 209 ++++++++++++--- 8 files changed, 412 insertions(+), 110 deletions(-) create mode 100644 src/taskiq_cancellation/abc/started_listening_event.py diff --git a/src/taskiq_cancellation/abc/__init__.py b/src/taskiq_cancellation/abc/__init__.py index 1decfa6..0e98393 100644 --- a/src/taskiq_cancellation/abc/__init__.py +++ b/src/taskiq_cancellation/abc/__init__.py @@ -1,6 +1,7 @@ from .backend import CancellationBackend from .notifier import CancellationNotifier from .state_holder import CancellationStateHolder +from .started_listening_event import StartedListeningEvent -__all__ = ["CancellationBackend", "CancellationNotifier", "CancellationStateHolder"] +__all__ = ["CancellationBackend", "CancellationNotifier", "CancellationStateHolder", "StartedListeningEvent"] diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index c3befb9..eb34c27 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -1,4 +1,6 @@ import abc +import enum +import inspect import asyncio from typing import Callable, Annotated, TypeVar, Awaitable from typing_extensions import ParamSpec, Self @@ -7,14 +9,21 @@ from anyio.abc import TaskStatus from taskiq import Context, TaskiqDepends, AsyncBroker, TaskiqEvents, TaskiqState -from taskiq_cancellation.utils import combines +from taskiq_cancellation.utils import combines, StopTaskGroupException from taskiq_cancellation.exceptions import TaskCancellationException +from .started_listening_event import StartedListeningEvent + P = ParamSpec("P") R = TypeVar("R") +class CancellationType(str, enum.Enum): + EDGE = "edge" + LEVEL = "level" + + class CancellationBackend(abc.ABC): """ Base class for cancellation backend @@ -48,7 +57,7 @@ async def cancel(self, task_id: str) -> None: @abc.abstractmethod async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: """ Listens for cancellation messages and raises :ref:`TaskCancellationException` when @@ -121,7 +130,10 @@ def with_broker(self, broker: AsyncBroker) -> Self: return self - def cancellable(self, task: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + def cancellable( + self, + cancellation_type: CancellationType = CancellationType.EDGE + ) -> Callable[[Callable[..., Awaitable]], Callable[..., Awaitable]]: """ Decorator that makes funcion cancellable @@ -136,73 +148,185 @@ def cancellable(self, task: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[ :param task: Task function to wrap :returns: Cancellable task function """ - # Executor type depends on receiver configuration which we can't accessed in any way - if not asyncio.iscoroutinefunction(task): - raise ValueError("Can't cancel synchronous function") - - @combines(task) - async def wrapper( - *args, __taskiq_context: Annotated[Context, TaskiqDepends()], **kwargs - ): - task_id = __taskiq_context.message.task_id - result = None - - listener_exception: Exception | None = None - task_exception: Exception | None = None - cancelled_by_request: bool = False - - async with anyio.create_task_group() as group: - - async def listen_for_cancellation( - task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, - ): - nonlocal listener_exception, cancelled_by_request - - try: - await self.listen_for_cancellation(task_id, task_status) - except TaskCancellationException: - cancelled_by_request = True - except anyio.get_cancelled_exc_class(): - pass - except Exception as e: - listener_exception = e - finally: - group.cancel_scope.cancel() - - async def call_task(): - nonlocal result, task_exception - - try: - result = await task(*args, **kwargs) - except anyio.get_cancelled_exc_class(): - pass - except Exception as e: - task_exception = e - finally: - group.cancel_scope.cancel() - # Listen before checking for cancellation in state holder - # so the message won't get lost in non-persistent queues - await group.start(listen_for_cancellation) - if await self.is_cancelled(task_id): - cancelled_by_request = True - group.cancel_scope.cancel() + def decorator(task: Callable[P, Awaitable[R]]) -> Callable[..., Awaitable[R]]: + # Executor type depends on receiver configuration which we can't accessed in any way + if not inspect.iscoroutinefunction(task): + raise ValueError("Can't cancel synchronous function") + + @combines(task) + async def wrapper( + *args, __taskiq_context: Annotated[Context, TaskiqDepends()], **kwargs + ): + task_id = __taskiq_context.message.task_id + + if cancellation_type is CancellationType.EDGE: + task_wrapper = EdgeCancellationWrapper(self, task, task_id) + return await task_wrapper(*args, **kwargs) + elif cancellation_type is CancellationType.LEVEL: + task_wrapper = LevelCancellationWrapper(self, task, task_id) + return await task_wrapper(*args, **kwargs) else: - group.start_soon(call_task) - - if task_exception is not None: - raise task_exception - elif cancelled_by_request: - raise TaskCancellationException() - elif listener_exception is not None: - raise listener_exception - else: - return result + raise ValueError(f"Unknown cancellation type: {cancellation_type!r}") - return wrapper + return wrapper + return decorator + async def _broker_startup_handler(self, _: TaskiqState) -> None: await self.startup() async def _broker_shutdown_handler(self, _: TaskiqState) -> None: await self.shutdown() + + +class EdgeCancellationWrapper: + class ListeningEvent(StartedListeningEvent): + def __init__(self, task_status: TaskStatus) -> None: + self.task_status = task_status + + async def set(self): + self.task_status.started() + + async def wait(self): + # Can ignore, won't execute further before task status is set + pass + + def __init__(self, backend: CancellationBackend, task: Callable, task_id: str): + self.backend = backend + self.task = task + self.task_id = task_id + + async def __call__(self, *args, **kwargs): + result = None + + listener_exception: Exception | None = None + task_exception: Exception | None = None + cancelled_by_request: bool = False + + async with anyio.create_task_group() as group: + async def listen_for_cancellation( + task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ): + nonlocal listener_exception, cancelled_by_request + + event = self.ListeningEvent(task_status) + try: + await self.backend.listen_for_cancellation(self.task_id, event) + except TaskCancellationException: + cancelled_by_request = True + except anyio.get_cancelled_exc_class(): + pass + except Exception as e: + listener_exception = e + finally: + group.cancel_scope.cancel() + + async def call_task(): + nonlocal result, task_exception + + try: + result = await self.task(*args, **kwargs) + except anyio.get_cancelled_exc_class(): + pass + except Exception as e: + task_exception = e + finally: + group.cancel_scope.cancel() + + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues + await group.start(listen_for_cancellation) + if await self.backend.is_cancelled(self.task_id): + cancelled_by_request = True + group.cancel_scope.cancel() + else: + group.start_soon(call_task) + + if task_exception is not None: + raise task_exception + elif cancelled_by_request: + raise TaskCancellationException() + elif listener_exception is not None: + raise listener_exception + else: + return result + + +class LevelCancellationWrapper: + class ListeningEvent(StartedListeningEvent): + def __init__(self) -> None: + self.event = asyncio.Event() + + async def set(self): + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(self.event.set) + + async def wait(self): + await self.event.wait() + + def __init__(self, backend: CancellationBackend, task: Callable, task_id: str): + self.backend = backend + self.task = task + self.task_id = task_id + + async def __call__(self, *args, **kwargs): + result = None + + listener_exception: Exception | None = None + task_exception: Exception | None = None + cancelled_by_request: bool = False + + async def listen_for_cancellation(event: StartedListeningEvent): + nonlocal listener_exception, cancelled_by_request + + try: + await self.backend.listen_for_cancellation(self.task_id, event) + except TaskCancellationException: + cancelled_by_request = True + raise + except asyncio.CancelledError: + raise + except Exception as e: + listener_exception = e + raise + + async def call_task(): + nonlocal result, task_exception + + try: + result = await self.task(*args, **kwargs) + except asyncio.CancelledError: + raise + except Exception as e: + task_exception = e + raise + + try: + async with asyncio.TaskGroup() as tg: + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues + event = self.ListeningEvent() + tg.create_task(listen_for_cancellation(event)) + await event.wait() + + if await self.backend.is_cancelled(self.task_id): + cancelled_by_request = True + raise StopTaskGroupException() + + task_task = asyncio.create_task(call_task()) + await task_task + if not task_task.cancelled(): + raise StopTaskGroupException() + except Exception: + # Exceptions are stored in local vars, can ignore + pass + + if task_exception is not None: + raise task_exception + elif cancelled_by_request: + raise TaskCancellationException() + elif listener_exception is not None: + raise listener_exception + else: + return result diff --git a/src/taskiq_cancellation/abc/notifier.py b/src/taskiq_cancellation/abc/notifier.py index 3ba408f..5364e89 100644 --- a/src/taskiq_cancellation/abc/notifier.py +++ b/src/taskiq_cancellation/abc/notifier.py @@ -1,9 +1,10 @@ import abc -from anyio.abc import TaskStatus from taskiq.abc.serializer import TaskiqSerializer from taskiq.serializers import JSONSerializer +from .started_listening_event import StartedListeningEvent + class CancellationNotifier(abc.ABC): """Receives cancellation messages and notifies listeners of these messages""" @@ -31,7 +32,7 @@ async def cancel(self, task_id: str) -> None: @abc.abstractmethod async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: """ Listens for cancellation messages and raises :ref:`TaskCancellationException` when diff --git a/src/taskiq_cancellation/abc/started_listening_event.py b/src/taskiq_cancellation/abc/started_listening_event.py new file mode 100644 index 0000000..28a51eb --- /dev/null +++ b/src/taskiq_cancellation/abc/started_listening_event.py @@ -0,0 +1,11 @@ +import abc + + +class StartedListeningEvent(abc.ABC): + @abc.abstractmethod + async def set(self): + pass + + @abc.abstractmethod + async def wait(self): + pass diff --git a/src/taskiq_cancellation/notifiers/null.py b/src/taskiq_cancellation/notifiers/null.py index 9740012..7530df5 100644 --- a/src/taskiq_cancellation/notifiers/null.py +++ b/src/taskiq_cancellation/notifiers/null.py @@ -1,7 +1,6 @@ import asyncio -from anyio.abc import TaskStatus -from taskiq_cancellation.abc.notifier import CancellationNotifier +from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent class NullCancellationNotifier(CancellationNotifier): @@ -14,6 +13,10 @@ class NullCancellationNotifier(CancellationNotifier): async def cancel(self, task_id: str) -> None: pass - async def listen_for_cancellation(self, task_id: str, started_listening_task_status: TaskStatus) -> None: - started_listening_task_status.started() + async def listen_for_cancellation( + self, + task_id: str, + started_listening_event: StartedListeningEvent + ) -> None: + await started_listening_event.set() await asyncio.sleep(float("+inf")) diff --git a/src/taskiq_cancellation/notifiers/queue.py b/src/taskiq_cancellation/notifiers/queue.py index 3190501..f7f8b6d 100644 --- a/src/taskiq_cancellation/notifiers/queue.py +++ b/src/taskiq_cancellation/notifiers/queue.py @@ -2,9 +2,7 @@ import weakref import asyncio -from anyio.abc import TaskStatus - -from taskiq_cancellation.abc import CancellationNotifier +from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent from taskiq_cancellation.exceptions import TaskCancellationException from taskiq_cancellation.message import CancellationMessage @@ -30,7 +28,7 @@ async def shutdown(self) -> None: self.listener_task.cancel() async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: cancellations: asyncio.Queue[CancellationMessage] = asyncio.Queue() @@ -38,7 +36,7 @@ async def listen_for_cancellation( await self._create_listener_task() await self._subscribe(cancellations) - started_listening_task_status.started() + await started_listening_event.set() while True: cancellation_message = await cancellations.get() diff --git a/src/taskiq_cancellation/utils.py b/src/taskiq_cancellation/utils.py index 4d659f0..d94311b 100644 --- a/src/taskiq_cancellation/utils.py +++ b/src/taskiq_cancellation/utils.py @@ -5,7 +5,7 @@ from collections import OrderedDict -def combines(wrapped): +def combines(wrapped, add_var_parameters=False): """ Combines wrapped and wrapper functions signatures and type hints @@ -28,6 +28,11 @@ def foo(a: int, b = "lol"): print(inspect.signature(foo)) # (a: int, c: int, b='lol', *args, **kwargs) ''' + + :param wrapped: function to be wrapped + :type wrapped: Callable + :param add_var_parameters: add *args and **kwargs from wrapper to new signature + :type add_var_parameters: bool """ wrapped_signature: inspect.Signature = inspect.signature(wrapped) wrapped_type_hints: typing.Dict[str, str] = typing.get_type_hints(wrapped) @@ -42,8 +47,18 @@ def decorator(wrapper): f"Parameter {param_name} will be overwritten by wrapper function" ) + wrapper_parameters = OrderedDict() + for name, parameter in wrapper_signature.parameters.items(): + if not add_var_parameters: + if any(( + parameter.kind is inspect.Parameter.VAR_POSITIONAL, + parameter.kind is inspect.Parameter.VAR_KEYWORD + )): + continue + wrapper_parameters[name] = parameter + parameters = OrderedDict( - wrapped_signature.parameters, **wrapper_signature.parameters + wrapped_signature.parameters, **wrapper_parameters ) parameters = sorted( parameters.values(), @@ -70,4 +85,8 @@ def decorator(wrapper): return decorator +class StopTaskGroupException(Exception): + pass + + __all__ = ["combines"] diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index a790d1e..e3f6727 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -1,9 +1,10 @@ import pytest import asyncio +import anyio from taskiq import AsyncBroker, InMemoryBroker -from taskiq_cancellation.abc import CancellationBackend +from taskiq_cancellation.abc.backend import CancellationBackend, CancellationType from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend from taskiq_cancellation.exceptions import TaskCancellationException @@ -14,45 +15,189 @@ def broker(): @pytest.fixture -def backend(broker): +def backend(broker: AsyncBroker): return InMemoryCancellationBackend().with_broker(broker) -@pytest.mark.asyncio -async def test_task_success(broker: AsyncBroker, backend: CancellationBackend): - """Test that cancellable task can run successfully""" +class TestLevelCancellation: + @pytest.mark.asyncio + async def test_task_success(self, broker: AsyncBroker, backend: CancellationBackend): + @broker.task + @backend.cancellable(cancellation_type=CancellationType.LEVEL) + async def test_task(): + await asyncio.sleep(0.1) + + await broker.startup() + + task = await test_task.kiq() + result = await task.wait_result() + assert result.is_err is False + + await broker.shutdown() + + @pytest.mark.asyncio + async def test_task_cancellation(self, broker: AsyncBroker, backend: CancellationBackend): + @broker.task + @backend.cancellable(cancellation_type=CancellationType.LEVEL) + async def test_task(): + await asyncio.sleep(0.2) + raise ValueError() + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + @pytest.mark.asyncio + async def test_cancellation_interception(self, broker: AsyncBroker, backend: CancellationBackend): + cancelled_for_second_time = False + + task_started = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=CancellationType.LEVEL) + async def test_task(): + nonlocal cancelled_for_second_time + + try: + task_started.set() + await asyncio.sleep(0.5) + except asyncio.CancelledError: + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + cancelled_for_second_time = True + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + await task_started.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + assert cancelled_for_second_time is False + + await broker.shutdown() + + +class TestEdgeCancellation: + @pytest.mark.asyncio + async def test_task_success(self, broker: AsyncBroker, backend: CancellationBackend): + @broker.task + @backend.cancellable(cancellation_type=CancellationType.EDGE) + async def test_task(): + await asyncio.sleep(0.1) + + await broker.startup() + + task = await test_task.kiq() + result = await task.wait_result() + assert result.is_err is False + + await broker.shutdown() + + @pytest.mark.asyncio + async def test_task_cancellation(self, broker: AsyncBroker, backend: CancellationBackend): + @broker.task + @backend.cancellable(cancellation_type=CancellationType.EDGE) + async def test_task(): + await asyncio.sleep(0.2) + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + @pytest.mark.asyncio + async def test_repeated_cancellation(self, broker: AsyncBroker, backend: CancellationBackend): + cancelled_for_second_time = False + started_event = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=CancellationType.EDGE) + async def test_task(): + nonlocal cancelled_for_second_time + + try: + started_event.set() + await asyncio.sleep(1) + except anyio.get_cancelled_exc_class(): + # anyio cancels on any await after scope's cancellation + try: + await asyncio.sleep(0) + except anyio.get_cancelled_exc_class(): + cancelled_for_second_time = True + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + await started_event.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + assert cancelled_for_second_time is True + + await broker.shutdown() + +# @pytest.mark.asyncio +# async def test_task_success(broker: AsyncBroker, backend: CancellationBackend): +# """Test that cancellable task can run successfully""" + +# @broker.task +# @backend.cancellable +# async def task(): +# await asyncio.sleep(0.1) + +# await broker.startup() + +# t = await task.kiq() - @broker.task - @backend.cancellable - async def task(): - await asyncio.sleep(0.1) +# result = await t.wait_result() +# assert result.is_err is False - await broker.startup() +# await broker.shutdown() - t = await task.kiq() - result = await t.wait_result() - assert result.is_err is False +# @pytest.mark.asyncio +# async def test_task_cancellation(broker: AsyncBroker, backend: CancellationBackend): +# """Test that cancellable task can be cancelled""" - await broker.shutdown() +# @broker.task +# @backend.cancellable +# async def task(): +# await asyncio.sleep(0.3) +# await broker.startup() + +# t = await task.kiq() +# await backend.cancel(t.task_id) -@pytest.mark.asyncio -async def test_task_cancellation(broker: AsyncBroker, backend: CancellationBackend): - """Test that cancellable task can be cancelled""" +# with pytest.raises(TaskCancellationException): +# result = await t.wait_result() +# result.raise_for_error() - @broker.task - @backend.cancellable - async def task(): - await asyncio.sleep(0.3) - - await broker.startup() - - t = await task.kiq() - await backend.cancel(t.task_id) - - with pytest.raises(TaskCancellationException): - result = await t.wait_result() - result.raise_for_error() - - await broker.shutdown() +# await broker.shutdown() From 93ab0091543a43fdc70b8d323ab00a87be2a03a2 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sat, 8 Nov 2025 23:54:23 +0300 Subject: [PATCH 02/18] test: level and edge cancellation behaviour tests --- tests/test_cancellation.py | 186 +++++++++++++++---------------------- 1 file changed, 77 insertions(+), 109 deletions(-) diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index e3f6727..2cd665e 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -19,47 +19,80 @@ def backend(broker: AsyncBroker): return InMemoryCancellationBackend().with_broker(broker) -class TestLevelCancellation: - @pytest.mark.asyncio - async def test_task_success(self, broker: AsyncBroker, backend: CancellationBackend): - @broker.task - @backend.cancellable(cancellation_type=CancellationType.LEVEL) - async def test_task(): - await asyncio.sleep(0.1) - - await broker.startup() +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("cancellation_type"), (CancellationType.LEVEL, CancellationType.EDGE) +) +async def test_task_success( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, +): + """Tests that cancellable task can successfully finish""" + + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + await asyncio.sleep(0.1) + + await broker.startup() + + task = await test_task.kiq() + result = await task.wait_result() + assert result.is_err is False + + await broker.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("cancellation_type"), (CancellationType.LEVEL, CancellationType.EDGE) +) +async def test_task_cancellation( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, +): + """Tests that cancellable task can successfully cancel""" + + started_event = asyncio.Event() + + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + with pytest.raises(anyio.get_cancelled_exc_class()): + started_event.set() + await asyncio.sleep(0.2) - task = await test_task.kiq() - result = await task.wait_result() - assert result.is_err is False + await broker.startup() - await broker.shutdown() + task = await test_task.kiq() + assert await task.is_ready() is False - @pytest.mark.asyncio - async def test_task_cancellation(self, broker: AsyncBroker, backend: CancellationBackend): - @broker.task - @backend.cancellable(cancellation_type=CancellationType.LEVEL) - async def test_task(): - await asyncio.sleep(0.2) - raise ValueError() - - await broker.startup() + await started_event.wait() + await backend.cancel(task.task_id) - task = await test_task.kiq() - assert await task.is_ready() is False + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() - await backend.cancel(task.task_id) + await broker.shutdown() - with pytest.raises(TaskCancellationException): - result = await task.wait_result() - result.raise_for_error() - await broker.shutdown() - +class TestLevelCancellation: @pytest.mark.asyncio - async def test_cancellation_interception(self, broker: AsyncBroker, backend: CancellationBackend): - cancelled_for_second_time = False + async def test_cancellation_interception( + self, broker: AsyncBroker, backend: CancellationBackend + ): + """ + Tests that tasks can capture task cancellation + + asyncio raises asyncio.CancelledError only once. Task can intercept that to do cleanup, + but also can just ignore the cancellation request. + Docs: https://docs.python.org/3/library/asyncio-task.html#task-cancellation + """ + cancelled_for_second_time = False task_started = asyncio.Event() @broker.task @@ -75,12 +108,12 @@ async def test_task(): await asyncio.sleep(0) except asyncio.CancelledError: cancelled_for_second_time = True - + await broker.startup() task = await test_task.kiq() assert await task.is_ready() is False - + await task_started.wait() await backend.cancel(task.task_id) @@ -94,42 +127,16 @@ async def test_task(): class TestEdgeCancellation: @pytest.mark.asyncio - async def test_task_success(self, broker: AsyncBroker, backend: CancellationBackend): - @broker.task - @backend.cancellable(cancellation_type=CancellationType.EDGE) - async def test_task(): - await asyncio.sleep(0.1) - - await broker.startup() + async def test_repeated_cancellation( + self, broker: AsyncBroker, backend: CancellationBackend + ): + """ + Tests that task will have multiple cancellation exceptions - task = await test_task.kiq() - result = await task.wait_result() - assert result.is_err is False + anyio raises cancellation exception on every await + Docs: https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics + """ - await broker.shutdown() - - @pytest.mark.asyncio - async def test_task_cancellation(self, broker: AsyncBroker, backend: CancellationBackend): - @broker.task - @backend.cancellable(cancellation_type=CancellationType.EDGE) - async def test_task(): - await asyncio.sleep(0.2) - - await broker.startup() - - task = await test_task.kiq() - assert await task.is_ready() is False - - await backend.cancel(task.task_id) - - with pytest.raises(TaskCancellationException): - result = await task.wait_result() - result.raise_for_error() - - await broker.shutdown() - - @pytest.mark.asyncio - async def test_repeated_cancellation(self, broker: AsyncBroker, backend: CancellationBackend): cancelled_for_second_time = False started_event = asyncio.Event() @@ -140,7 +147,7 @@ async def test_task(): try: started_event.set() - await asyncio.sleep(1) + await asyncio.sleep(0.5) except anyio.get_cancelled_exc_class(): # anyio cancels on any await after scope's cancellation try: @@ -152,7 +159,7 @@ async def test_task(): task = await test_task.kiq() assert await task.is_ready() is False - + await started_event.wait() await backend.cancel(task.task_id) @@ -162,42 +169,3 @@ async def test_task(): assert cancelled_for_second_time is True await broker.shutdown() - -# @pytest.mark.asyncio -# async def test_task_success(broker: AsyncBroker, backend: CancellationBackend): -# """Test that cancellable task can run successfully""" - -# @broker.task -# @backend.cancellable -# async def task(): -# await asyncio.sleep(0.1) - -# await broker.startup() - -# t = await task.kiq() - -# result = await t.wait_result() -# assert result.is_err is False - -# await broker.shutdown() - - -# @pytest.mark.asyncio -# async def test_task_cancellation(broker: AsyncBroker, backend: CancellationBackend): -# """Test that cancellable task can be cancelled""" - -# @broker.task -# @backend.cancellable -# async def task(): -# await asyncio.sleep(0.3) - -# await broker.startup() - -# t = await task.kiq() -# await backend.cancel(t.task_id) - -# with pytest.raises(TaskCancellationException): -# result = await t.wait_result() -# result.raise_for_error() - -# await broker.shutdown() From c9eb438e0561f9ef1e8a198300a7bbd55b27dd4a Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 00:26:45 +0300 Subject: [PATCH 03/18] feat: allow cancellable decorator to omit parentesis --- src/taskiq_cancellation/abc/backend.py | 85 +++++++++++++++++--------- 1 file changed, 56 insertions(+), 29 deletions(-) diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index eb34c27..d3eb233 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -2,8 +2,8 @@ import enum import inspect import asyncio -from typing import Callable, Annotated, TypeVar, Awaitable -from typing_extensions import ParamSpec, Self +from typing import Callable, Annotated, Awaitable, overload, Optional, cast, TypeAlias +from typing_extensions import Self import anyio from anyio.abc import TaskStatus @@ -15,8 +15,7 @@ from .started_listening_event import StartedListeningEvent -P = ParamSpec("P") -R = TypeVar("R") +AsyncCallable: TypeAlias = Callable[..., Awaitable] class CancellationType(str, enum.Enum): @@ -129,11 +128,25 @@ def with_broker(self, broker: AsyncBroker) -> Self: ) return self + + @overload + def cancellable( + self, + cancellation_type: AsyncCallable + ) -> AsyncCallable: + pass + @overload def cancellable( self, - cancellation_type: CancellationType = CancellationType.EDGE - ) -> Callable[[Callable[..., Awaitable]], Callable[..., Awaitable]]: + cancellation_type: Optional[CancellationType] = None + ) -> Callable[[AsyncCallable], AsyncCallable]: + pass + + def cancellable( + self, + cancellation_type = None + ): """ Decorator that makes funcion cancellable @@ -149,29 +162,43 @@ def cancellable( :returns: Cancellable task function """ - def decorator(task: Callable[P, Awaitable[R]]) -> Callable[..., Awaitable[R]]: - # Executor type depends on receiver configuration which we can't accessed in any way - if not inspect.iscoroutinefunction(task): - raise ValueError("Can't cancel synchronous function") - - @combines(task) - async def wrapper( - *args, __taskiq_context: Annotated[Context, TaskiqDepends()], **kwargs - ): - task_id = __taskiq_context.message.task_id - - if cancellation_type is CancellationType.EDGE: - task_wrapper = EdgeCancellationWrapper(self, task, task_id) - return await task_wrapper(*args, **kwargs) - elif cancellation_type is CancellationType.LEVEL: - task_wrapper = LevelCancellationWrapper(self, task, task_id) - return await task_wrapper(*args, **kwargs) - else: - raise ValueError(f"Unknown cancellation type: {cancellation_type!r}") - - return wrapper - return decorator - + defaults = { + "cancellation_type": CancellationType.EDGE + } + + def make_decorator( + cancellation_type: CancellationType + ): + def decorator(task: AsyncCallable) -> AsyncCallable: + # Executor type depends on receiver configuration which we can't accessed in any way + if not inspect.iscoroutinefunction(task): + raise ValueError("Can't cancel synchronous function") + + @combines(task) + async def wrapper( + *args, __taskiq_context: Annotated[Context, TaskiqDepends()], **kwargs + ): + task_id = __taskiq_context.message.task_id + + if cancellation_type is CancellationType.EDGE: + task_wrapper = EdgeCancellationWrapper(self, task, task_id) + return await task_wrapper(*args, **kwargs) + elif cancellation_type is CancellationType.LEVEL: + task_wrapper = LevelCancellationWrapper(self, task, task_id) + return await task_wrapper(*args, **kwargs) + else: + raise ValueError(f"Unknown cancellation type: {cancellation_type!r}") + + return wrapper + return decorator + + if callable(cancellation_type): + task = cast(Callable[..., Awaitable], cancellation_type) + return make_decorator(**defaults)(task) + else: + return make_decorator( + cancellation_type=cancellation_type or defaults["cancellation_type"] + ) async def _broker_startup_handler(self, _: TaskiqState) -> None: await self.startup() From 69746a9105d8eb9f1600c37917c1dd9da2dbf2b6 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 00:41:45 +0300 Subject: [PATCH 04/18] fix: allow direct function calls --- src/taskiq_cancellation/abc/backend.py | 8 +++++++- tests/test_cancellation.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index d3eb233..4b895f7 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -176,8 +176,14 @@ def decorator(task: AsyncCallable) -> AsyncCallable: @combines(task) async def wrapper( - *args, __taskiq_context: Annotated[Context, TaskiqDepends()], **kwargs + *args, + __taskiq_context: Annotated[Context, TaskiqDepends()] = None, # type: ignore + **kwargs ): + if __taskiq_context is None: + # Ran the function directly, without kiq + return task(*args, **kwargs) + task_id = __taskiq_context.message.task_id if cancellation_type is CancellationType.EDGE: diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 2cd665e..1b51f3f 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -19,6 +19,17 @@ def backend(broker: AsyncBroker): return InMemoryCancellationBackend().with_broker(broker) +@pytest.mark.asyncio +async def test_task_direct_call(broker: AsyncBroker, backend: CancellationBackend): + @broker.task + @backend.cancellable() + async def test_task(): + return True + + result = await test_task() + assert result + + @pytest.mark.asyncio @pytest.mark.parametrize( ("cancellation_type"), (CancellationType.LEVEL, CancellationType.EDGE) From 9e46f4ab74f2ee58132c695c2d7c38ef057d7be3 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 00:43:17 +0300 Subject: [PATCH 05/18] test: test cancellable w/o parentesis --- tests/test_backend.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/test_backend.py diff --git a/tests/test_backend.py b/tests/test_backend.py new file mode 100644 index 0000000..81d2a8c --- /dev/null +++ b/tests/test_backend.py @@ -0,0 +1,35 @@ +import pytest +import inspect + +from taskiq import AsyncBroker, InMemoryBroker + +from taskiq_cancellation.abc.backend import CancellationBackend +from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend + + +@pytest.fixture +def backend(): + return InMemoryCancellationBackend() + + +@pytest.mark.asyncio +async def test_decorator_without_parentesis(backend: CancellationBackend): + @backend.cancellable + async def test_task(): + pass + + task = test_task() + assert inspect.iscoroutine(task) + await task + + +@pytest.mark.asyncio +async def test_decorator_with_parentesis(backend: CancellationBackend): + @backend.cancellable() + async def test_task(): + pass + + task = test_task() + assert inspect.iscoroutine(task) + await task + \ No newline at end of file From 3162c18090ee637f669b5b5f990c0d01f0c31ef6 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 00:44:28 +0300 Subject: [PATCH 06/18] fix,test: remove deadlock from waiting for task to start --- tests/test_cancellation.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 1b51f3f..71b6d33 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -80,7 +80,8 @@ async def test_task(): task = await test_task.kiq() assert await task.is_ready() is False - await started_event.wait() + async with asyncio.timeout(1): + await started_event.wait() await backend.cancel(task.task_id) with pytest.raises(TaskCancellationException): @@ -125,7 +126,8 @@ async def test_task(): task = await test_task.kiq() assert await task.is_ready() is False - await task_started.wait() + async with asyncio.timeout(1): + await task_started.wait() await backend.cancel(task.task_id) with pytest.raises(TaskCancellationException): @@ -171,7 +173,8 @@ async def test_task(): task = await test_task.kiq() assert await task.is_ready() is False - await started_event.wait() + async with asyncio.timeout(1): + await started_event.wait() await backend.cancel(task.task_id) with pytest.raises(TaskCancellationException): From acf37248b58116cf1f0ecfc62561046613d2b975 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 00:55:17 +0300 Subject: [PATCH 07/18] fix: await for task when direct calling --- src/taskiq_cancellation/abc/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index 4b895f7..a88ff56 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -182,7 +182,7 @@ async def wrapper( ): if __taskiq_context is None: # Ran the function directly, without kiq - return task(*args, **kwargs) + return await task(*args, **kwargs) task_id = __taskiq_context.message.task_id From 6d990f43ab5c3837691eefeccebf9875a77128ce Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 01:03:57 +0300 Subject: [PATCH 08/18] ci: run tests on windows and macos (just like taskiq) --- .github/workflows/run_tests.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 3a45a80..fdd04d6 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -6,13 +6,13 @@ on: jobs: run-tests: - runs-on: ubuntu-latest - strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + os: [ubuntu-latest, windows-latest, macos-latest] fail-fast: false + runs-on: ${{ matrix.os }} steps: - name: Checkout uses: actions/checkout@v5 @@ -25,8 +25,8 @@ jobs: - name: Setup uv uses: astral-sh/setup-uv@v7 - - name: Create virtual environment - run: uv venv .venv && source .venv/bin/activate + # - name: Create virtual environment + # run: uv venv .venv && source .venv/bin/activate - name: Install modules run: uv sync From 9b6df4f62b641f43dea65d20d01ddf1a9fa8f3e0 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 01:23:19 +0300 Subject: [PATCH 09/18] fix: add async_timeout for Python <3.11 --- pyproject.toml | 6 +++++- src/taskiq_cancellation/abc/backend.py | 7 ++++++- tests/test_cancellation.py | 12 +++++++++--- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 869d263..8a9e60f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,11 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["taskiq", "typing-extensions>=4.13.2"] +dependencies = [ + "async-timeout>=5.0.1 ; python_full_version < '3.11'", + "taskiq", + "typing-extensions>=4.15.0 ; python_full_version < '3.11'", +] [project.optional-dependencies] redis = ["redis~=3.0"] diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index a88ff56..1231e5f 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -1,9 +1,14 @@ import abc +import sys import enum import inspect import asyncio from typing import Callable, Annotated, Awaitable, overload, Optional, cast, TypeAlias -from typing_extensions import Self + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self import anyio from anyio.abc import TaskStatus diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 71b6d33..54531d6 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -1,6 +1,12 @@ +import sys import pytest import asyncio +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout + import anyio from taskiq import AsyncBroker, InMemoryBroker @@ -80,7 +86,7 @@ async def test_task(): task = await test_task.kiq() assert await task.is_ready() is False - async with asyncio.timeout(1): + async with timeout(1): await started_event.wait() await backend.cancel(task.task_id) @@ -126,7 +132,7 @@ async def test_task(): task = await test_task.kiq() assert await task.is_ready() is False - async with asyncio.timeout(1): + async with timeout(1): await task_started.wait() await backend.cancel(task.task_id) @@ -173,7 +179,7 @@ async def test_task(): task = await test_task.kiq() assert await task.is_ready() is False - async with asyncio.timeout(1): + async with timeout(1): await started_event.wait() await backend.cancel(task.task_id) From 253cf73cb177fbe0351d8adf08440539d1db568e Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 01:25:15 +0300 Subject: [PATCH 10/18] chore: ruff check fixes and ruff format --- src/taskiq_cancellation/__init__.py | 5 +- src/taskiq_cancellation/abc/__init__.py | 7 +- src/taskiq_cancellation/abc/backend.py | 64 +++++++++---------- src/taskiq_cancellation/abc/notifier.py | 8 +-- src/taskiq_cancellation/abc/state_holder.py | 2 +- src/taskiq_cancellation/backends/in_memory.py | 2 +- src/taskiq_cancellation/backends/modular.py | 13 ++-- .../notifiers/in_memory.py | 9 +-- src/taskiq_cancellation/notifiers/null.py | 8 +-- src/taskiq_cancellation/notifiers/queue.py | 1 + .../state_holders/in_memory.py | 4 +- src/taskiq_cancellation/state_holders/null.py | 4 +- src/taskiq_cancellation/utils.py | 16 ++--- tests/test_backend.py | 3 - tests/test_cancellation.py | 2 +- 15 files changed, 71 insertions(+), 77 deletions(-) diff --git a/src/taskiq_cancellation/__init__.py b/src/taskiq_cancellation/__init__.py index 57864e4..db24b56 100644 --- a/src/taskiq_cancellation/__init__.py +++ b/src/taskiq_cancellation/__init__.py @@ -2,7 +2,4 @@ from .backends.modular import ModularCancellationBackend -__all__ = [ - "CancellationBackend", - "ModularCancellationBackend" -] +__all__ = ["CancellationBackend", "ModularCancellationBackend"] diff --git a/src/taskiq_cancellation/abc/__init__.py b/src/taskiq_cancellation/abc/__init__.py index 0e98393..45239e2 100644 --- a/src/taskiq_cancellation/abc/__init__.py +++ b/src/taskiq_cancellation/abc/__init__.py @@ -4,4 +4,9 @@ from .started_listening_event import StartedListeningEvent -__all__ = ["CancellationBackend", "CancellationNotifier", "CancellationStateHolder", "StartedListeningEvent"] +__all__ = [ + "CancellationBackend", + "CancellationNotifier", + "CancellationStateHolder", + "StartedListeningEvent", +] diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index 1231e5f..d3fed52 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -32,6 +32,7 @@ class CancellationBackend(abc.ABC): """ Base class for cancellation backend """ + def __init__(self) -> None: super().__init__() @@ -41,7 +42,7 @@ def __init__(self) -> None: async def is_cancelled(self, task_id: str) -> bool: """ Checks if a task with task id of *task_id* is set to be cancelled - + :param task_id: task id to check :type task_id: str :returns: task cancellation state @@ -67,13 +68,13 @@ async def listen_for_cancellation( Listens for cancellation messages and raises :ref:`TaskCancellationException` when receives :ref:`CancellationMessage` with same id as *task_id*. - This function is used in :func:`cancellable` decorator. - Call `started_listening_task_status.started()` when the listener is ready + This function is used in :func:`cancellable` decorator. + Call `started_listening_task_status.started()` when the listener is ready to receive messages. - :param task_id: id of task that will be listened for + :param task_id: id of task that will be listened for :type task_id: str - :param started_listening_task_status: + :param started_listening_task_status: :type started_listening_task_status: anyio.abc.TaskStatus """ pass @@ -88,7 +89,7 @@ async def startup(self) -> None: async def shutdown(self) -> None: """Shuts down cancellation backend - + Triggered only if backend has a broker set. To set a broker use :ref:`with_broker`. """ pass @@ -97,7 +98,7 @@ def with_broker(self, broker: AsyncBroker) -> Self: """ Set a broker and return updated cancellation backend - Sets up startup and shutdown event handlers for backend's startup + Sets up startup and shutdown event handlers for backend's startup and shutdown methods respectfully :param broker: new broker @@ -133,25 +134,18 @@ def with_broker(self, broker: AsyncBroker) -> Self: ) return self - + @overload - def cancellable( - self, - cancellation_type: AsyncCallable - ) -> AsyncCallable: + def cancellable(self, cancellation_type: AsyncCallable) -> AsyncCallable: pass @overload def cancellable( - self, - cancellation_type: Optional[CancellationType] = None + self, cancellation_type: Optional[CancellationType] = None ) -> Callable[[AsyncCallable], AsyncCallable]: pass - def cancellable( - self, - cancellation_type = None - ): + def cancellable(self, cancellation_type=None): """ Decorator that makes funcion cancellable @@ -167,13 +161,9 @@ def cancellable( :returns: Cancellable task function """ - defaults = { - "cancellation_type": CancellationType.EDGE - } + defaults = {"cancellation_type": CancellationType.EDGE} - def make_decorator( - cancellation_type: CancellationType - ): + def make_decorator(cancellation_type: CancellationType): def decorator(task: AsyncCallable) -> AsyncCallable: # Executor type depends on receiver configuration which we can't accessed in any way if not inspect.iscoroutinefunction(task): @@ -181,9 +171,9 @@ def decorator(task: AsyncCallable) -> AsyncCallable: @combines(task) async def wrapper( - *args, + *args, __taskiq_context: Annotated[Context, TaskiqDepends()] = None, # type: ignore - **kwargs + **kwargs, ): if __taskiq_context is None: # Ran the function directly, without kiq @@ -192,17 +182,20 @@ async def wrapper( task_id = __taskiq_context.message.task_id if cancellation_type is CancellationType.EDGE: - task_wrapper = EdgeCancellationWrapper(self, task, task_id) + task_wrapper = EdgeCancellationWrapper(self, task, task_id) return await task_wrapper(*args, **kwargs) - elif cancellation_type is CancellationType.LEVEL: - task_wrapper = LevelCancellationWrapper(self, task, task_id) + elif cancellation_type is CancellationType.LEVEL: + task_wrapper = LevelCancellationWrapper(self, task, task_id) return await task_wrapper(*args, **kwargs) else: - raise ValueError(f"Unknown cancellation type: {cancellation_type!r}") + raise ValueError( + f"Unknown cancellation type: {cancellation_type!r}" + ) return wrapper + return decorator - + if callable(cancellation_type): task = cast(Callable[..., Awaitable], cancellation_type) return make_decorator(**defaults)(task) @@ -229,7 +222,7 @@ async def set(self): async def wait(self): # Can ignore, won't execute further before task status is set pass - + def __init__(self, backend: CancellationBackend, task: Callable, task_id: str): self.backend = backend self.task = task @@ -243,6 +236,7 @@ async def __call__(self, *args, **kwargs): cancelled_by_request: bool = False async with anyio.create_task_group() as group: + async def listen_for_cancellation( task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ): @@ -295,7 +289,7 @@ class LevelCancellationWrapper: class ListeningEvent(StartedListeningEvent): def __init__(self) -> None: self.event = asyncio.Event() - + async def set(self): loop = asyncio.get_running_loop() loop.call_soon_threadsafe(self.event.set) @@ -351,11 +345,11 @@ async def call_task(): if await self.backend.is_cancelled(self.task_id): cancelled_by_request = True raise StopTaskGroupException() - + task_task = asyncio.create_task(call_task()) await task_task if not task_task.cancelled(): - raise StopTaskGroupException() + raise StopTaskGroupException() except Exception: # Exceptions are stored in local vars, can ignore pass diff --git a/src/taskiq_cancellation/abc/notifier.py b/src/taskiq_cancellation/abc/notifier.py index 5364e89..938618a 100644 --- a/src/taskiq_cancellation/abc/notifier.py +++ b/src/taskiq_cancellation/abc/notifier.py @@ -38,13 +38,13 @@ async def listen_for_cancellation( Listens for cancellation messages and raises :ref:`TaskCancellationException` when receives :ref:`CancellationMessage` with same id as *task_id*. - This function is used in :func:`cancellable` decorator of :ref:`ModularCancellationBackend`. - Call `started_listening_task_status.started()` when the listener is ready + This function is used in :func:`cancellable` decorator of :ref:`ModularCancellationBackend`. + Call `started_listening_task_status.started()` when the listener is ready to receive messages. - :param task_id: id of task that will be listened for + :param task_id: id of task that will be listened for :type task_id: str - :param started_listening_task_status: + :param started_listening_task_status: :type started_listening_task_status: anyio.abc.TaskStatus """ pass diff --git a/src/taskiq_cancellation/abc/state_holder.py b/src/taskiq_cancellation/abc/state_holder.py index 8416440..bb2a707 100644 --- a/src/taskiq_cancellation/abc/state_holder.py +++ b/src/taskiq_cancellation/abc/state_holder.py @@ -18,7 +18,7 @@ async def cancel(self, task_id: str) -> None: async def is_cancelled(self, task_id: str) -> bool: """ Checks if a task with task id of *task_id* is set to be cancelled - + :param task_id: task id to check :type task_id: str :returns: task cancellation state diff --git a/src/taskiq_cancellation/backends/in_memory.py b/src/taskiq_cancellation/backends/in_memory.py index ee5e385..6c8f172 100644 --- a/src/taskiq_cancellation/backends/in_memory.py +++ b/src/taskiq_cancellation/backends/in_memory.py @@ -8,5 +8,5 @@ class InMemoryCancellationBackend(ModularCancellationBackend): def __init__(self, **kwargs): super().__init__( state_holder=InMemoryCancellationStateHolder(**kwargs), - notifier=InMemoryCancellationNotifier(**kwargs) + notifier=InMemoryCancellationNotifier(**kwargs), ) diff --git a/src/taskiq_cancellation/backends/modular.py b/src/taskiq_cancellation/backends/modular.py index 2d87dfe..a1e7b49 100644 --- a/src/taskiq_cancellation/backends/modular.py +++ b/src/taskiq_cancellation/backends/modular.py @@ -1,4 +1,8 @@ -from taskiq_cancellation.abc import CancellationBackend, CancellationNotifier, CancellationStateHolder +from taskiq_cancellation.abc import ( + CancellationBackend, + CancellationNotifier, + CancellationStateHolder, +) import anyio from anyio.abc import TaskStatus @@ -6,12 +10,13 @@ class ModularCancellationBackend(CancellationBackend): """ - Modular cancellation backend made up of :class:`CancellationStateHolder` + Modular cancellation backend made up of :class:`CancellationStateHolder` and :class:`CancellationNotifier` - - `CancellationStateHolder` stores cancellation state and blocks the task from being run. + - `CancellationStateHolder` stores cancellation state and blocks the task from being run. - `CancellationNotifier` receives cancellation messages and cancels already running tasks. """ + def __init__( self, state_holder: CancellationStateHolder, notifier: CancellationNotifier ): @@ -39,7 +44,7 @@ async def startup(self) -> None: await super().startup() await self.state_holder.startup() await self.notifier.startup() - + async def shutdown(self) -> None: await super().shutdown() await self.state_holder.shutdown() diff --git a/src/taskiq_cancellation/notifiers/in_memory.py b/src/taskiq_cancellation/notifiers/in_memory.py index 748a129..b03dbeb 100644 --- a/src/taskiq_cancellation/notifiers/in_memory.py +++ b/src/taskiq_cancellation/notifiers/in_memory.py @@ -11,17 +11,14 @@ class InMemoryCancellationNotifier(QueueCancellationNotifier): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - + self.messages: asyncio.Queue[CancellationMessage] = asyncio.Queue() async def cancel(self, task_id: str) -> None: timestamp = time.time() - + await self.messages.put( - CancellationMessage( - task_id=task_id, - timestamp=timestamp - ) + CancellationMessage(task_id=task_id, timestamp=timestamp) ) async def _listen(self, started_listening: asyncio.Event) -> None: diff --git a/src/taskiq_cancellation/notifiers/null.py b/src/taskiq_cancellation/notifiers/null.py index 7530df5..bb19211 100644 --- a/src/taskiq_cancellation/notifiers/null.py +++ b/src/taskiq_cancellation/notifiers/null.py @@ -6,17 +6,15 @@ class NullCancellationNotifier(CancellationNotifier): """ \"Do nothing\" cancellation notifier - + May be useful if there's no need or ability to use an actual notifier """ - + async def cancel(self, task_id: str) -> None: pass async def listen_for_cancellation( - self, - task_id: str, - started_listening_event: StartedListeningEvent + self, task_id: str, started_listening_event: StartedListeningEvent ) -> None: await started_listening_event.set() await asyncio.sleep(float("+inf")) diff --git a/src/taskiq_cancellation/notifiers/queue.py b/src/taskiq_cancellation/notifiers/queue.py index f7f8b6d..28ef1ee 100644 --- a/src/taskiq_cancellation/notifiers/queue.py +++ b/src/taskiq_cancellation/notifiers/queue.py @@ -14,6 +14,7 @@ class QueueCancellationNotifier(CancellationNotifier): Requires :func:`_listen` to be implemeted """ + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/src/taskiq_cancellation/state_holders/in_memory.py b/src/taskiq_cancellation/state_holders/in_memory.py index 23a7089..b63277e 100644 --- a/src/taskiq_cancellation/state_holders/in_memory.py +++ b/src/taskiq_cancellation/state_holders/in_memory.py @@ -3,10 +3,10 @@ class InMemoryCancellationStateHolder(CancellationStateHolder): """In memory cancellation state holder used for testing""" - + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - + self.state_holder: dict[str, bool] = {} async def cancel(self, task_id: str) -> None: diff --git a/src/taskiq_cancellation/state_holders/null.py b/src/taskiq_cancellation/state_holders/null.py index 4383071..b24462d 100644 --- a/src/taskiq_cancellation/state_holders/null.py +++ b/src/taskiq_cancellation/state_holders/null.py @@ -4,10 +4,10 @@ class NullCancellationStateHolder(CancellationStateHolder): """ \"Do nothing\" cancellation state holder - + May be useful if there's no need or ability to use an actual state holder """ - + async def cancel(self, task_id: str) -> None: pass diff --git a/src/taskiq_cancellation/utils.py b/src/taskiq_cancellation/utils.py index d94311b..9a9bd4f 100644 --- a/src/taskiq_cancellation/utils.py +++ b/src/taskiq_cancellation/utils.py @@ -50,16 +50,16 @@ def decorator(wrapper): wrapper_parameters = OrderedDict() for name, parameter in wrapper_signature.parameters.items(): if not add_var_parameters: - if any(( - parameter.kind is inspect.Parameter.VAR_POSITIONAL, - parameter.kind is inspect.Parameter.VAR_KEYWORD - )): + if any( + ( + parameter.kind is inspect.Parameter.VAR_POSITIONAL, + parameter.kind is inspect.Parameter.VAR_KEYWORD, + ) + ): continue wrapper_parameters[name] = parameter - - parameters = OrderedDict( - wrapped_signature.parameters, **wrapper_parameters - ) + + parameters = OrderedDict(wrapped_signature.parameters, **wrapper_parameters) parameters = sorted( parameters.values(), key=lambda p: p.kind + (0.5 if p.default != inspect.Parameter.empty else 0), diff --git a/tests/test_backend.py b/tests/test_backend.py index 81d2a8c..4c4b8c2 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,8 +1,6 @@ import pytest import inspect -from taskiq import AsyncBroker, InMemoryBroker - from taskiq_cancellation.abc.backend import CancellationBackend from taskiq_cancellation.backends.in_memory import InMemoryCancellationBackend @@ -32,4 +30,3 @@ async def test_task(): task = test_task() assert inspect.iscoroutine(task) await task - \ No newline at end of file diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 54531d6..5494673 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -32,7 +32,7 @@ async def test_task_direct_call(broker: AsyncBroker, backend: CancellationBacken async def test_task(): return True - result = await test_task() + result = await test_task() assert result From fda260c331b19c30c853644ac3fed4047c3b9f7a Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 01:26:47 +0300 Subject: [PATCH 11/18] fix: adapt counter example to new module structure --- examples/counter/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/counter/main.py b/examples/counter/main.py index a148aae..61f9789 100644 --- a/examples/counter/main.py +++ b/examples/counter/main.py @@ -1,7 +1,7 @@ import asyncio from taskiq_redis import PubSubBroker, RedisAsyncResultBackend -from taskiq_cancellation.integrations.redis import RedisCancellationBackend +from taskiq_cancellation.backends.redis import RedisCancellationBackend url = "redis://localhost" From 4581f462b06f9126666260a12e6f6abc83686a55 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Sun, 9 Nov 2025 22:23:58 +0300 Subject: [PATCH 12/18] fix: a lot of things 1. 3.10 tests crashing with "TypeError: Cannot instantiate typing.Union" Added type to TaskiqDepends so dependency would instantiate correctly 2. 3.9 tests crashing with "ImportError: cannot import name 'TypeAlias' from 'typing'" Added TypeAlias from typing_extensions 3. Rename LevelCancellation to EdgeCancellation and vice versa because my terminology was wrong :p 4. Move cancellation handlers to a separate submodule to restrict edge cancellation to Python 3.11+ because that uses asyncio.TaskGroup 5. InMemoryBackend asyncio.Queue behaviour changes for 3.9 support 6. Typing fixes --- pyproject.toml | 13 +- src/taskiq_cancellation/abc/backend.py | 191 ++---------------- src/taskiq_cancellation/backends/modular.py | 8 +- .../cancellation_handlers/__init__.py | 22 ++ .../cancellation_type.py | 6 + .../cancellation_handlers/edge.py | 110 ++++++++++ .../cancellation_handlers/level.py | 84 ++++++++ .../notifiers/in_memory.py | 11 +- src/taskiq_cancellation/notifiers/queue.py | 3 +- tests/test_cancellation.py | 146 ++++++++----- 10 files changed, 359 insertions(+), 235 deletions(-) create mode 100644 src/taskiq_cancellation/cancellation_handlers/__init__.py create mode 100644 src/taskiq_cancellation/cancellation_handlers/cancellation_type.py create mode 100644 src/taskiq_cancellation/cancellation_handlers/edge.py create mode 100644 src/taskiq_cancellation/cancellation_handlers/level.py diff --git a/pyproject.toml b/pyproject.toml index 8a9e60f..54d2eec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,8 +48,12 @@ check = "mypy --install-types --non-interactive {args:src/taskiq_cancellation te [tool.mypy] ignore_missing_imports = true -exclude = ["examples"] - +exclude = [ + # Added so that "mypy ." would work + "examples", + # Contains Python 3.11+ code, has to be excluded. Runtime checks don't work. + "src/taskiq_cancellation/cancellation_handlers/edge.py", +] [tool.coverage.run] source_pkgs = ["taskiq_cancellation", "tests"] @@ -74,3 +78,8 @@ dev = [ "pytest-asyncio>=0.24.0", "ruff>=0.14.4", ] + +[tool.ruff] +# Edge cancellation is currently Python 3.11+ which causes linter to freak out +# For such versioning cases tests must be written +target-version = "py314" diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index d3fed52..13260c9 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -1,21 +1,21 @@ import abc import sys -import enum import inspect -import asyncio -from typing import Callable, Annotated, Awaitable, overload, Optional, cast, TypeAlias +from typing import Callable, Annotated, Awaitable, overload, Optional, cast, Union if sys.version_info >= (3, 11): - from typing import Self + from typing import Self, TypeAlias else: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias -import anyio -from anyio.abc import TaskStatus from taskiq import Context, TaskiqDepends, AsyncBroker, TaskiqEvents, TaskiqState -from taskiq_cancellation.utils import combines, StopTaskGroupException -from taskiq_cancellation.exceptions import TaskCancellationException +from taskiq_cancellation.utils import combines +from taskiq_cancellation.cancellation_handlers import ( + CancellationType, + LevelCancellationHandler, + EdgeCancellationHandler, +) from .started_listening_event import StartedListeningEvent @@ -23,11 +23,6 @@ AsyncCallable: TypeAlias = Callable[..., Awaitable] -class CancellationType(str, enum.Enum): - EDGE = "edge" - LEVEL = "level" - - class CancellationBackend(abc.ABC): """ Base class for cancellation backend @@ -36,7 +31,7 @@ class CancellationBackend(abc.ABC): def __init__(self) -> None: super().__init__() - self.broker: AsyncBroker | None = None + self.broker: Union[AsyncBroker, None] = None @abc.abstractmethod async def is_cancelled(self, task_id: str) -> bool: @@ -161,7 +156,7 @@ def cancellable(self, cancellation_type=None): :returns: Cancellable task function """ - defaults = {"cancellation_type": CancellationType.EDGE} + defaults = {"cancellation_type": CancellationType.LEVEL} def make_decorator(cancellation_type: CancellationType): def decorator(task: AsyncCallable) -> AsyncCallable: @@ -172,7 +167,8 @@ def decorator(task: AsyncCallable) -> AsyncCallable: @combines(task) async def wrapper( *args, - __taskiq_context: Annotated[Context, TaskiqDepends()] = None, # type: ignore + __taskiq_context: Annotated[Context, TaskiqDepends(Context)] = None, # type: ignore + # __taskiq_context: Annotated[Context, TaskiqDepends()], # type: ignore **kwargs, ): if __taskiq_context is None: @@ -182,11 +178,11 @@ async def wrapper( task_id = __taskiq_context.message.task_id if cancellation_type is CancellationType.EDGE: - task_wrapper = EdgeCancellationWrapper(self, task, task_id) - return await task_wrapper(*args, **kwargs) + edge_handler = EdgeCancellationHandler(self, task, task_id) + return await edge_handler(*args, **kwargs) elif cancellation_type is CancellationType.LEVEL: - task_wrapper = LevelCancellationWrapper(self, task, task_id) - return await task_wrapper(*args, **kwargs) + level_handler = LevelCancellationHandler(self, task, task_id) + return await level_handler(*args, **kwargs) else: raise ValueError( f"Unknown cancellation type: {cancellation_type!r}" @@ -209,156 +205,3 @@ async def _broker_startup_handler(self, _: TaskiqState) -> None: async def _broker_shutdown_handler(self, _: TaskiqState) -> None: await self.shutdown() - - -class EdgeCancellationWrapper: - class ListeningEvent(StartedListeningEvent): - def __init__(self, task_status: TaskStatus) -> None: - self.task_status = task_status - - async def set(self): - self.task_status.started() - - async def wait(self): - # Can ignore, won't execute further before task status is set - pass - - def __init__(self, backend: CancellationBackend, task: Callable, task_id: str): - self.backend = backend - self.task = task - self.task_id = task_id - - async def __call__(self, *args, **kwargs): - result = None - - listener_exception: Exception | None = None - task_exception: Exception | None = None - cancelled_by_request: bool = False - - async with anyio.create_task_group() as group: - - async def listen_for_cancellation( - task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, - ): - nonlocal listener_exception, cancelled_by_request - - event = self.ListeningEvent(task_status) - try: - await self.backend.listen_for_cancellation(self.task_id, event) - except TaskCancellationException: - cancelled_by_request = True - except anyio.get_cancelled_exc_class(): - pass - except Exception as e: - listener_exception = e - finally: - group.cancel_scope.cancel() - - async def call_task(): - nonlocal result, task_exception - - try: - result = await self.task(*args, **kwargs) - except anyio.get_cancelled_exc_class(): - pass - except Exception as e: - task_exception = e - finally: - group.cancel_scope.cancel() - - # Listen before checking for cancellation in state holder - # so the message won't get lost in non-persistent queues - await group.start(listen_for_cancellation) - if await self.backend.is_cancelled(self.task_id): - cancelled_by_request = True - group.cancel_scope.cancel() - else: - group.start_soon(call_task) - - if task_exception is not None: - raise task_exception - elif cancelled_by_request: - raise TaskCancellationException() - elif listener_exception is not None: - raise listener_exception - else: - return result - - -class LevelCancellationWrapper: - class ListeningEvent(StartedListeningEvent): - def __init__(self) -> None: - self.event = asyncio.Event() - - async def set(self): - loop = asyncio.get_running_loop() - loop.call_soon_threadsafe(self.event.set) - - async def wait(self): - await self.event.wait() - - def __init__(self, backend: CancellationBackend, task: Callable, task_id: str): - self.backend = backend - self.task = task - self.task_id = task_id - - async def __call__(self, *args, **kwargs): - result = None - - listener_exception: Exception | None = None - task_exception: Exception | None = None - cancelled_by_request: bool = False - - async def listen_for_cancellation(event: StartedListeningEvent): - nonlocal listener_exception, cancelled_by_request - - try: - await self.backend.listen_for_cancellation(self.task_id, event) - except TaskCancellationException: - cancelled_by_request = True - raise - except asyncio.CancelledError: - raise - except Exception as e: - listener_exception = e - raise - - async def call_task(): - nonlocal result, task_exception - - try: - result = await self.task(*args, **kwargs) - except asyncio.CancelledError: - raise - except Exception as e: - task_exception = e - raise - - try: - async with asyncio.TaskGroup() as tg: - # Listen before checking for cancellation in state holder - # so the message won't get lost in non-persistent queues - event = self.ListeningEvent() - tg.create_task(listen_for_cancellation(event)) - await event.wait() - - if await self.backend.is_cancelled(self.task_id): - cancelled_by_request = True - raise StopTaskGroupException() - - task_task = asyncio.create_task(call_task()) - await task_task - if not task_task.cancelled(): - raise StopTaskGroupException() - except Exception: - # Exceptions are stored in local vars, can ignore - pass - - if task_exception is not None: - raise task_exception - elif cancelled_by_request: - raise TaskCancellationException() - elif listener_exception is not None: - raise listener_exception - else: - return result diff --git a/src/taskiq_cancellation/backends/modular.py b/src/taskiq_cancellation/backends/modular.py index a1e7b49..2da1ec4 100644 --- a/src/taskiq_cancellation/backends/modular.py +++ b/src/taskiq_cancellation/backends/modular.py @@ -2,10 +2,10 @@ CancellationBackend, CancellationNotifier, CancellationStateHolder, + StartedListeningEvent, ) import anyio -from anyio.abc import TaskStatus class ModularCancellationBackend(CancellationBackend): @@ -34,11 +34,9 @@ async def cancel(self, task_id: str): group.start_soon(self.notifier.cancel, task_id) async def listen_for_cancellation( - self, task_id: str, started_listening_task_status: TaskStatus[None] + self, task_id: str, started_listening_event: StartedListeningEvent ): - await self.notifier.listen_for_cancellation( - task_id, started_listening_task_status - ) + await self.notifier.listen_for_cancellation(task_id, started_listening_event) async def startup(self) -> None: await super().startup() diff --git a/src/taskiq_cancellation/cancellation_handlers/__init__.py b/src/taskiq_cancellation/cancellation_handlers/__init__.py new file mode 100644 index 0000000..731a4a1 --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/__init__.py @@ -0,0 +1,22 @@ +import sys + +from .cancellation_type import CancellationType +from .level import LevelCancellationHandler + +if sys.version_info >= (3, 11): + from .edge import EdgeCancellationHandler +else: + + class EdgeCancellationHandler: + def __init__(self, *args, **kwargs) -> None: + raise NotImplementedError( + "Edge cancellation is not supported for Python <3.11" + ) + + async def __call__(self, *args, **kwargs) -> None: + raise NotImplementedError( + "Edge cancellation is not supported for Python <3.11" + ) + + +__all__ = ["CancellationType", "LevelCancellationHandler", "EdgeCancellationHandler"] diff --git a/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py new file mode 100644 index 0000000..0dd0d89 --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py @@ -0,0 +1,6 @@ +import enum + + +class CancellationType(str, enum.Enum): + EDGE = "edge" + LEVEL = "level" diff --git a/src/taskiq_cancellation/cancellation_handlers/edge.py b/src/taskiq_cancellation/cancellation_handlers/edge.py new file mode 100644 index 0000000..b8cd29b --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/edge.py @@ -0,0 +1,110 @@ +# FIXME: bunch of issues because of Python 3.11+ exclusivity +# +# Edge cancellation handler is using asyncio.TaskGroup which was introduced in 3.11 and +# uses expect-star syntax that was also introduced in the same version +# - mypy has to ignore this file because it can't finish static parsing +# - ruff's python version has to be set at 3.11+ so it wouldn't complain +# +# I'm not sure how to mitigate these issues. Maybe this can be put in a separate module somehow +# and then integrated? Maybe this can be rewritten to not use TaskGroup (probably easier to do)? + +import logging +import asyncio +from typing import Callable, TYPE_CHECKING + +from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent +from taskiq_cancellation.exceptions import TaskCancellationException +from taskiq_cancellation.utils import StopTaskGroupException + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +class EdgeCancellationHandler: + class ListeningEvent(StartedListeningEvent): + def __init__(self) -> None: + self.event = asyncio.Event() + + async def set(self): + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(self.event.set) + + async def wait(self): + await self.event.wait() + + def __init__(self, backend: "CancellationBackend", task: Callable, task_id: str): + self.backend = backend + self.task = task + self.task_id = task_id + + async def __call__(self, *args, **kwargs): + result = None + + listener_exception: Exception | None = None + task_exception: Exception | None = None + cancelled_by_request: bool = False + + async def listen_for_cancellation(event: StartedListeningEvent): + nonlocal listener_exception, cancelled_by_request + + try: + await self.backend.listen_for_cancellation(self.task_id, event) + except TaskCancellationException: + cancelled_by_request = True + raise + except asyncio.CancelledError: + raise + except Exception as e: + listener_exception = e + raise + + async def call_task(): + nonlocal result, task_exception + + try: + result = await self.task(*args, **kwargs) + except asyncio.CancelledError: + raise + except Exception as e: + task_exception = e + raise + + try: + async with asyncio.TaskGroup() as tg: + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues + event = self.ListeningEvent() + tg.create_task(listen_for_cancellation(event)) + await event.wait() + + if await self.backend.is_cancelled(self.task_id): + cancelled_by_request = True + raise StopTaskGroupException() + + task_task = asyncio.create_task(call_task()) + await task_task + if not task_task.cancelled(): + raise StopTaskGroupException() + except* StopTaskGroupException: + pass + except* Exception as exc_group: + uncaught_exceptions = list( + filter( + lambda e: e == task_exception or e == listener_exception, + exc_group.exceptions, + ) + ) + + if uncaught_exceptions: + logging.log(logging.ERROR, "Uncaught exception in TaskGroup") + for e in uncaught_exceptions: + logging.exception(e) + + if task_exception is not None: + raise task_exception + elif cancelled_by_request: + raise TaskCancellationException() + elif listener_exception is not None: + raise listener_exception + else: + return result diff --git a/src/taskiq_cancellation/cancellation_handlers/level.py b/src/taskiq_cancellation/cancellation_handlers/level.py new file mode 100644 index 0000000..a068f64 --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/level.py @@ -0,0 +1,84 @@ +from typing import Callable, TYPE_CHECKING + +import anyio +from anyio.abc import TaskStatus + +from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent +from taskiq_cancellation.exceptions import TaskCancellationException + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +class LevelCancellationHandler: + class ListeningEvent(StartedListeningEvent): + def __init__(self, task_status: TaskStatus) -> None: + self.task_status = task_status + + async def set(self): + self.task_status.started() + + async def wait(self): + # Can ignore, won't execute further before task status is set + pass + + def __init__(self, backend: "CancellationBackend", task: Callable, task_id: str): + self.backend = backend + self.task = task + self.task_id = task_id + + async def __call__(self, *args, **kwargs): + result = None + + listener_exception: Exception | None = None + task_exception: Exception | None = None + cancelled_by_request: bool = False + + async with anyio.create_task_group() as group: + + async def listen_for_cancellation( + task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ): + nonlocal listener_exception, cancelled_by_request + + event = self.ListeningEvent(task_status) + try: + await self.backend.listen_for_cancellation(self.task_id, event) + except TaskCancellationException: + cancelled_by_request = True + except anyio.get_cancelled_exc_class(): + pass + except Exception as e: + listener_exception = e + finally: + group.cancel_scope.cancel() + + async def call_task(): + nonlocal result, task_exception + + try: + result = await self.task(*args, **kwargs) + except anyio.get_cancelled_exc_class(): + pass + except Exception as e: + task_exception = e + finally: + group.cancel_scope.cancel() + + # Listen before checking for cancellation in state holder + # so the message won't get lost in non-persistent queues + await group.start(listen_for_cancellation) + if await self.backend.is_cancelled(self.task_id): + cancelled_by_request = True + group.cancel_scope.cancel() + else: + group.start_soon(call_task) + + if task_exception is not None: + raise task_exception + elif cancelled_by_request: + raise TaskCancellationException() + elif listener_exception is not None: + raise listener_exception + else: + return result diff --git a/src/taskiq_cancellation/notifiers/in_memory.py b/src/taskiq_cancellation/notifiers/in_memory.py index b03dbeb..1dfe59f 100644 --- a/src/taskiq_cancellation/notifiers/in_memory.py +++ b/src/taskiq_cancellation/notifiers/in_memory.py @@ -1,5 +1,6 @@ import time import asyncio +from typing import Union from taskiq_cancellation.message import CancellationMessage @@ -12,9 +13,14 @@ class InMemoryCancellationNotifier(QueueCancellationNotifier): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.messages: asyncio.Queue[CancellationMessage] = asyncio.Queue() + # In Python 3.9 queues must be created inside a running loop + # Source: https://stackoverflow.com/questions/53724665 + self.messages: Union[asyncio.Queue[CancellationMessage], None] = None async def cancel(self, task_id: str) -> None: + if self.messages is None: + self.messages = asyncio.Queue() + timestamp = time.time() await self.messages.put( @@ -22,6 +28,9 @@ async def cancel(self, task_id: str) -> None: ) async def _listen(self, started_listening: asyncio.Event) -> None: + if self.messages is None: + self.messages = asyncio.Queue() + loop = asyncio.get_running_loop() loop.call_soon_threadsafe(started_listening.set) diff --git a/src/taskiq_cancellation/notifiers/queue.py b/src/taskiq_cancellation/notifiers/queue.py index 28ef1ee..f18c083 100644 --- a/src/taskiq_cancellation/notifiers/queue.py +++ b/src/taskiq_cancellation/notifiers/queue.py @@ -1,6 +1,7 @@ import abc import weakref import asyncio +from typing import Union from taskiq_cancellation.abc import CancellationNotifier, StartedListeningEvent from taskiq_cancellation.exceptions import TaskCancellationException @@ -18,7 +19,7 @@ class QueueCancellationNotifier(CancellationNotifier): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.listener_task: asyncio.Task | None = None + self.listener_task: Union[asyncio.Task, None] = None self.queues: weakref.WeakSet[asyncio.Queue[CancellationMessage]] = ( weakref.WeakSet() ) diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 5494673..1d9e2ba 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -25,80 +25,122 @@ def backend(broker: AsyncBroker): return InMemoryCancellationBackend().with_broker(broker) -@pytest.mark.asyncio -async def test_task_direct_call(broker: AsyncBroker, backend: CancellationBackend): - @broker.task - @backend.cancellable() - async def test_task(): - return True - - result = await test_task() - assert result - - @pytest.mark.asyncio @pytest.mark.parametrize( ("cancellation_type"), (CancellationType.LEVEL, CancellationType.EDGE) ) -async def test_task_success( +async def test_task_direct_call( broker: AsyncBroker, backend: CancellationBackend, cancellation_type: CancellationType, ): - """Tests that cancellable task can successfully finish""" - @broker.task @backend.cancellable(cancellation_type=cancellation_type) async def test_task(): - await asyncio.sleep(0.1) + return True - await broker.startup() + result = await test_task() + assert result - task = await test_task.kiq() - result = await task.wait_result() - assert result.is_err is False - await broker.shutdown() +class TestTaskSuccess: + types = [CancellationType.LEVEL] + if sys.version_info >= (3, 11): + types.append(CancellationType.EDGE) + @pytest.mark.asyncio + @pytest.mark.parametrize(("cancellation_type"), types) + @staticmethod + async def test_task_success( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, + ): + """Tests that cancellable task can successfully finish""" -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("cancellation_type"), (CancellationType.LEVEL, CancellationType.EDGE) -) -async def test_task_cancellation( - broker: AsyncBroker, - backend: CancellationBackend, - cancellation_type: CancellationType, -): - """Tests that cancellable task can successfully cancel""" + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + await asyncio.sleep(0.1) - started_event = asyncio.Event() + await broker.startup() - @broker.task - @backend.cancellable(cancellation_type=cancellation_type) - async def test_task(): - with pytest.raises(anyio.get_cancelled_exc_class()): - started_event.set() - await asyncio.sleep(0.2) + task = await test_task.kiq() + result = await task.wait_result() + assert result.is_err is False - await broker.startup() + await broker.shutdown() - task = await test_task.kiq() - assert await task.is_ready() is False - async with timeout(1): - await started_event.wait() - await backend.cancel(task.task_id) +class TestTaskCancellation: + types = [CancellationType.LEVEL] + if sys.version_info > (3, 11): + types.append(CancellationType.EDGE) - with pytest.raises(TaskCancellationException): - result = await task.wait_result() - result.raise_for_error() + @pytest.mark.asyncio + @pytest.mark.parametrize(("cancellation_type"), types) + @staticmethod + async def test_task_cancellation( + broker: AsyncBroker, + backend: CancellationBackend, + cancellation_type: CancellationType, + ): + """Tests that cancellable task can successfully cancel""" - await broker.shutdown() + started_event = asyncio.Event() + @broker.task + @backend.cancellable(cancellation_type=cancellation_type) + async def test_task(): + with pytest.raises(anyio.get_cancelled_exc_class()): + started_event.set() + await asyncio.sleep(0.2) + + await broker.startup() + + task = await test_task.kiq() + assert await task.is_ready() is False + + async with timeout(1): + await started_event.wait() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +class TestEdgeCancellation: + @pytest.mark.asyncio + @pytest.mark.skipif( + sys.version_info >= (3, 11), + reason="Edge cancellation is currently not supported for Python <3.11", + ) + async def test_non_supported( + self, broker: AsyncBroker, backend: CancellationBackend + ): + @broker.task + @backend.cancellable(cancellation_type=CancellationType.EDGE) + async def test_task(): + await asyncio.sleep(0.1) + + await broker.startup() + + task = await test_task.kiq() + + with pytest.raises(NotImplementedError): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() -class TestLevelCancellation: @pytest.mark.asyncio + @pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Edge cancellation is currently not supported for Python <3.11", + ) async def test_cancellation_interception( self, broker: AsyncBroker, backend: CancellationBackend ): @@ -114,7 +156,7 @@ async def test_cancellation_interception( task_started = asyncio.Event() @broker.task - @backend.cancellable(cancellation_type=CancellationType.LEVEL) + @backend.cancellable(cancellation_type=CancellationType.EDGE) async def test_task(): nonlocal cancelled_for_second_time @@ -144,7 +186,7 @@ async def test_task(): await broker.shutdown() -class TestEdgeCancellation: +class TestLevelCancellation: @pytest.mark.asyncio async def test_repeated_cancellation( self, broker: AsyncBroker, backend: CancellationBackend @@ -160,13 +202,13 @@ async def test_repeated_cancellation( started_event = asyncio.Event() @broker.task - @backend.cancellable(cancellation_type=CancellationType.EDGE) + @backend.cancellable(cancellation_type=CancellationType.LEVEL) async def test_task(): nonlocal cancelled_for_second_time try: started_event.set() - await asyncio.sleep(0.5) + await asyncio.sleep(1) except anyio.get_cancelled_exc_class(): # anyio cancels on any await after scope's cancellation try: From c7d06ec6c17a6bf87e30ccb957546b6d6e70979d Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Mon, 10 Nov 2025 18:33:11 +0300 Subject: [PATCH 13/18] feat: actually working type hinting --- pyproject.toml | 2 +- src/taskiq_cancellation/abc/backend.py | 29 ++++++++++------- .../cancellation_handlers/__init__.py | 14 ++------ .../{edge.py => edge_3_11.py} | 29 +++++++++++++---- .../edge_non_supported.py | 27 ++++++++++++++++ .../cancellation_handlers/level.py | 32 +++++++++++++++---- src/taskiq_cancellation/utils.py | 4 ++- 7 files changed, 98 insertions(+), 39 deletions(-) rename src/taskiq_cancellation/cancellation_handlers/{edge.py => edge_3_11.py} (83%) create mode 100644 src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py diff --git a/pyproject.toml b/pyproject.toml index 54d2eec..677fb5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ exclude = [ # Added so that "mypy ." would work "examples", # Contains Python 3.11+ code, has to be excluded. Runtime checks don't work. - "src/taskiq_cancellation/cancellation_handlers/edge.py", + "src/taskiq_cancellation/cancellation_handlers/edge_3_11.py", ] [tool.coverage.run] diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index 13260c9..9c4abd7 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -1,12 +1,12 @@ import abc import sys import inspect -from typing import Callable, Annotated, Awaitable, overload, Optional, cast, Union +from typing import Callable, Annotated, overload, Optional, cast, Union if sys.version_info >= (3, 11): - from typing import Self, TypeAlias + from typing import Self, ParamSpec, TypeVar else: - from typing_extensions import Self, TypeAlias + from typing_extensions import Self, ParamSpec, TypeVar from taskiq import Context, TaskiqDepends, AsyncBroker, TaskiqEvents, TaskiqState @@ -20,7 +20,8 @@ from .started_listening_event import StartedListeningEvent -AsyncCallable: TypeAlias = Callable[..., Awaitable] +Params = ParamSpec("Params") +Result = TypeVar("Result") class CancellationBackend(abc.ABC): @@ -131,13 +132,15 @@ def with_broker(self, broker: AsyncBroker) -> Self: return self @overload - def cancellable(self, cancellation_type: AsyncCallable) -> AsyncCallable: + def cancellable( + self, cancellation_type: Callable[Params, Result] + ) -> Callable[Params, Result]: pass @overload def cancellable( self, cancellation_type: Optional[CancellationType] = None - ) -> Callable[[AsyncCallable], AsyncCallable]: + ) -> Callable[[Callable[Params, Result]], Callable[Params, Result]]: pass def cancellable(self, cancellation_type=None): @@ -159,7 +162,9 @@ def cancellable(self, cancellation_type=None): defaults = {"cancellation_type": CancellationType.LEVEL} def make_decorator(cancellation_type: CancellationType): - def decorator(task: AsyncCallable) -> AsyncCallable: + def decorator( + task: Callable[Params, Result], / + ) -> Callable[Params, Result]: # Executor type depends on receiver configuration which we can't accessed in any way if not inspect.iscoroutinefunction(task): raise ValueError("Can't cancel synchronous function") @@ -168,9 +173,8 @@ def decorator(task: AsyncCallable) -> AsyncCallable: async def wrapper( *args, __taskiq_context: Annotated[Context, TaskiqDepends(Context)] = None, # type: ignore - # __taskiq_context: Annotated[Context, TaskiqDepends()], # type: ignore **kwargs, - ): + ) -> Result: if __taskiq_context is None: # Ran the function directly, without kiq return await task(*args, **kwargs) @@ -188,13 +192,14 @@ async def wrapper( f"Unknown cancellation type: {cancellation_type!r}" ) - return wrapper + # Wrapper adds a key-word only param with default value + casted_wrapper = cast(Callable[Params, Result], wrapper) + return casted_wrapper return decorator if callable(cancellation_type): - task = cast(Callable[..., Awaitable], cancellation_type) - return make_decorator(**defaults)(task) + return make_decorator(**defaults)(cancellation_type) else: return make_decorator( cancellation_type=cancellation_type or defaults["cancellation_type"] diff --git a/src/taskiq_cancellation/cancellation_handlers/__init__.py b/src/taskiq_cancellation/cancellation_handlers/__init__.py index 731a4a1..a228c8d 100644 --- a/src/taskiq_cancellation/cancellation_handlers/__init__.py +++ b/src/taskiq_cancellation/cancellation_handlers/__init__.py @@ -4,19 +4,9 @@ from .level import LevelCancellationHandler if sys.version_info >= (3, 11): - from .edge import EdgeCancellationHandler + from .edge_3_11 import EdgeCancellationHandler else: - - class EdgeCancellationHandler: - def __init__(self, *args, **kwargs) -> None: - raise NotImplementedError( - "Edge cancellation is not supported for Python <3.11" - ) - - async def __call__(self, *args, **kwargs) -> None: - raise NotImplementedError( - "Edge cancellation is not supported for Python <3.11" - ) + from .edge_non_supported import EdgeCancellationHandler __all__ = ["CancellationType", "LevelCancellationHandler", "EdgeCancellationHandler"] diff --git a/src/taskiq_cancellation/cancellation_handlers/edge.py b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py similarity index 83% rename from src/taskiq_cancellation/cancellation_handlers/edge.py rename to src/taskiq_cancellation/cancellation_handlers/edge_3_11.py index b8cd29b..441d66e 100644 --- a/src/taskiq_cancellation/cancellation_handlers/edge.py +++ b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py @@ -8,9 +8,16 @@ # I'm not sure how to mitigate these issues. Maybe this can be put in a separate module somehow # and then integrated? Maybe this can be rewritten to not use TaskGroup (probably easier to do)? +import sys import logging import asyncio -from typing import Callable, TYPE_CHECKING +from collections.abc import Coroutine +from typing import Callable, TYPE_CHECKING, Generic, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar from taskiq_cancellation.abc.started_listening_event import StartedListeningEvent from taskiq_cancellation.exceptions import TaskCancellationException @@ -20,7 +27,11 @@ from taskiq_cancellation.abc.backend import CancellationBackend -class EdgeCancellationHandler: +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class EdgeCancellationHandler(Generic[Params, Result]): class ListeningEvent(StartedListeningEvent): def __init__(self) -> None: self.event = asyncio.Event() @@ -32,13 +43,19 @@ async def set(self): async def wait(self): await self.event.wait() - def __init__(self, backend: "CancellationBackend", task: Callable, task_id: str): + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ): self.backend = backend self.task = task self.task_id = task_id - async def __call__(self, *args, **kwargs): - result = None + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + result: Result = None + # type: ignore listener_exception: Exception | None = None task_exception: Exception | None = None @@ -96,7 +113,7 @@ async def call_task(): ) if uncaught_exceptions: - logging.log(logging.ERROR, "Uncaught exception in TaskGroup") + logging.log(logging.ERROR, "Uncaught exceptions in TaskGroup") for e in uncaught_exceptions: logging.exception(e) diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py new file mode 100644 index 0000000..7184105 --- /dev/null +++ b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py @@ -0,0 +1,27 @@ +import sys +from typing import Callable, TYPE_CHECKING, Coroutine, Generic, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar + +if TYPE_CHECKING: + from taskiq_cancellation.abc.backend import CancellationBackend + + +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class EdgeCancellationHandler(Generic[Params, Result]): + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ) -> None: + raise NotImplementedError("Edge cancellation is not supported for Python <3.11") + + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + raise NotImplementedError("Edge cancellation is not supported for Python <3.11") diff --git a/src/taskiq_cancellation/cancellation_handlers/level.py b/src/taskiq_cancellation/cancellation_handlers/level.py index a068f64..55f740b 100644 --- a/src/taskiq_cancellation/cancellation_handlers/level.py +++ b/src/taskiq_cancellation/cancellation_handlers/level.py @@ -1,4 +1,11 @@ -from typing import Callable, TYPE_CHECKING +import sys +from collections.abc import Coroutine +from typing import Callable, TYPE_CHECKING, Generic, Union, cast, Any + +if sys.version_info >= (3, 11): + from typing import ParamSpec, TypeVar +else: + from typing_extensions import ParamSpec, TypeVar import anyio from anyio.abc import TaskStatus @@ -10,7 +17,11 @@ from taskiq_cancellation.abc.backend import CancellationBackend -class LevelCancellationHandler: +Params = ParamSpec("Params") +Result = TypeVar("Result") + + +class LevelCancellationHandler(Generic[Params, Result]): class ListeningEvent(StartedListeningEvent): def __init__(self, task_status: TaskStatus) -> None: self.task_status = task_status @@ -22,16 +33,21 @@ async def wait(self): # Can ignore, won't execute further before task status is set pass - def __init__(self, backend: "CancellationBackend", task: Callable, task_id: str): + def __init__( + self, + backend: "CancellationBackend", + task: Callable[Params, Coroutine[Any, Any, Result]], + task_id: str, + ): self.backend = backend self.task = task self.task_id = task_id - async def __call__(self, *args, **kwargs): - result = None + async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: + result: Union[Result, None] = None - listener_exception: Exception | None = None - task_exception: Exception | None = None + listener_exception: Union[Exception, None] = None + task_exception: Union[Exception, None] = None cancelled_by_request: bool = False async with anyio.create_task_group() as group: @@ -81,4 +97,6 @@ async def call_task(): elif listener_exception is not None: raise listener_exception else: + # If the task is finished, it is definitely not None + result = cast(Result, result) return result diff --git a/src/taskiq_cancellation/utils.py b/src/taskiq_cancellation/utils.py index 9a9bd4f..a128266 100644 --- a/src/taskiq_cancellation/utils.py +++ b/src/taskiq_cancellation/utils.py @@ -5,7 +5,9 @@ from collections import OrderedDict -def combines(wrapped, add_var_parameters=False): +def combines( + wrapped: typing.Callable, add_var_parameters: bool = False +) -> typing.Callable: """ Combines wrapped and wrapper functions signatures and type hints From bc84268cc17c4cc7de3a2875810b5c491ea6f916 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Mon, 10 Nov 2025 18:33:39 +0300 Subject: [PATCH 14/18] test: add sync function cancellable test and docstrings --- tests/test_backend.py | 18 ++++++++++++++++++ tests/test_cancellation.py | 4 ++++ 2 files changed, 22 insertions(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index 4c4b8c2..459492d 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -12,6 +12,8 @@ def backend(): @pytest.mark.asyncio async def test_decorator_without_parentesis(backend: CancellationBackend): + """Tests that cancellable decorator works without parentesis""" + @backend.cancellable async def test_task(): pass @@ -23,6 +25,8 @@ async def test_task(): @pytest.mark.asyncio async def test_decorator_with_parentesis(backend: CancellationBackend): + """Tests that cancellable decorator works with parentesis""" + @backend.cancellable() async def test_task(): pass @@ -30,3 +34,17 @@ async def test_task(): task = test_task() assert inspect.iscoroutine(task) await task + + +@pytest.mark.asyncio +async def test_decorator_with_sync_function(backend: CancellationBackend): + """ + Tests that cancellable decorator doesn't work with synchronous functions + + To launch a synchronous function we need to know how to do it and only Receiver knows that + """ + + with pytest.raises(ValueError): + @backend.cancellable + def test_task(): + pass diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 1d9e2ba..892e988 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -34,6 +34,8 @@ async def test_task_direct_call( backend: CancellationBackend, cancellation_type: CancellationType, ): + """Tests that cancellable function can be called directly""" + @broker.task @backend.cancellable(cancellation_type=cancellation_type) async def test_task(): @@ -121,6 +123,8 @@ class TestEdgeCancellation: async def test_non_supported( self, broker: AsyncBroker, backend: CancellationBackend ): + """Tests that edge cancellation raises NotImplementedError in Python <3.11""" + @broker.task @backend.cancellable(cancellation_type=CancellationType.EDGE) async def test_task(): From 5352d11992ba815b086b164e6a156f6ea58daec9 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Tue, 11 Nov 2025 00:36:53 +0300 Subject: [PATCH 15/18] fix: close connections in redis and aiopika integrations (how did I miss this) --- src/taskiq_cancellation/notifiers/aiopika.py | 71 ++++++++++--------- src/taskiq_cancellation/notifiers/redis.py | 5 ++ .../state_holders/redis.py | 5 ++ 3 files changed, 48 insertions(+), 33 deletions(-) diff --git a/src/taskiq_cancellation/notifiers/aiopika.py b/src/taskiq_cancellation/notifiers/aiopika.py index 8e55ab8..205ede2 100644 --- a/src/taskiq_cancellation/notifiers/aiopika.py +++ b/src/taskiq_cancellation/notifiers/aiopika.py @@ -19,43 +19,48 @@ def __init__(self, url: str, **kwargs): async def cancel(self, task_id: str) -> None: timestamp = time.time() + connection = await aio_pika.connect_robust(self.url) - channel = await connection.channel() - exchange = await channel.declare_exchange( - self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True - ) + async with connection: + channel = await connection.channel() - await exchange.publish( - aio_pika.Message( - body=self.serializer.dumpb( - model_dump( - CancellationMessage(task_id=task_id, timestamp=timestamp) + exchange = await channel.declare_exchange( + self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True + ) + + await exchange.publish( + aio_pika.Message( + body=self.serializer.dumpb( + model_dump( + CancellationMessage(task_id=task_id, timestamp=timestamp) + ) ) - ) - ), - routing_key="", - ) + ), + routing_key="", + ) async def _listen(self, started_listening: asyncio.Event): connection = await aio_pika.connect_robust(self.url) - channel = await connection.channel() - - exchange = await channel.declare_exchange( - self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True - ) - queue = await channel.declare_queue(exclusive=True, auto_delete=True) - await queue.bind(exchange) - - loop = asyncio.get_running_loop() - loop.call_soon_threadsafe(started_listening.set) - - async with queue.iterator() as queue_iter: - async for message in queue_iter: - cancellation_message = model_validate( - CancellationMessage, self.serializer.loadb(message.body) - ) - - for queue in self.queues: - await queue.put(cancellation_message) - await message.ack() + + async with connection: + channel = await connection.channel() + + exchange = await channel.declare_exchange( + self.EXCHANGE_NAME, aio_pika.ExchangeType.FANOUT, durable=True + ) + queue = await channel.declare_queue(exclusive=True, auto_delete=True) + await queue.bind(exchange) + + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(started_listening.set) + + async with queue.iterator() as queue_iter: + async for message in queue_iter: + cancellation_message = model_validate( + CancellationMessage, self.serializer.loadb(message.body) + ) + + for queue in self.queues: + await queue.put(cancellation_message) + await message.ack() diff --git a/src/taskiq_cancellation/notifiers/redis.py b/src/taskiq_cancellation/notifiers/redis.py index 2efa7c2..51b29bc 100644 --- a/src/taskiq_cancellation/notifiers/redis.py +++ b/src/taskiq_cancellation/notifiers/redis.py @@ -54,3 +54,8 @@ async def _listen(self, started_listening: asyncio.Event): ) for queue in self.queues: await queue.put(cancellation_message) + + async def shutdown(self) -> None: + await super().shutdown() + await self.connection_pool.aclose() + \ No newline at end of file diff --git a/src/taskiq_cancellation/state_holders/redis.py b/src/taskiq_cancellation/state_holders/redis.py index 262f121..be86f08 100644 --- a/src/taskiq_cancellation/state_holders/redis.py +++ b/src/taskiq_cancellation/state_holders/redis.py @@ -17,5 +17,10 @@ async def is_cancelled(self, task_id: str) -> bool: response = await conn.get(self._task_key(task_id)) return bool(response) + async def shutdown(self) -> None: + await super().shutdown() + await self.connection_pool.aclose() + def _task_key(self, task_id: str) -> str: return f"__cancellation_status_{task_id}" + From 9b1d9f79e4aa4ea8b8eae3467c494078ff6a4d96 Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Tue, 11 Nov 2025 00:40:25 +0300 Subject: [PATCH 16/18] docs: more docstrings --- src/taskiq_cancellation/abc/backend.py | 3 ++- src/taskiq_cancellation/abc/notifier.py | 7 +++---- .../abc/started_listening_event.py | 11 +++++++++++ src/taskiq_cancellation/backends/in_memory.py | 6 ++++++ .../cancellation_handlers/cancellation_type.py | 2 ++ .../cancellation_handlers/edge_3_11.py | 13 +++++++++++-- .../cancellation_handlers/edge_non_supported.py | 7 +++++++ .../cancellation_handlers/level.py | 8 ++++++++ src/taskiq_cancellation/notifiers/aiopika.py | 2 ++ src/taskiq_cancellation/notifiers/queue.py | 1 + src/taskiq_cancellation/notifiers/redis.py | 2 ++ 11 files changed, 55 insertions(+), 7 deletions(-) diff --git a/src/taskiq_cancellation/abc/backend.py b/src/taskiq_cancellation/abc/backend.py index 9c4abd7..6f20b5e 100644 --- a/src/taskiq_cancellation/abc/backend.py +++ b/src/taskiq_cancellation/abc/backend.py @@ -155,7 +155,8 @@ def cancellable(self, cancellation_type=None): - Raises :ref:`TaskCancellationException` if listener task receives cancellation message - If listener task raises an exception, task is cancelled and exception is propogated upwards - :param task: Task function to wrap + :param cancellation_type: type of cancellation used + :type cancellation_type: CancellationType :returns: Cancellable task function """ diff --git a/src/taskiq_cancellation/abc/notifier.py b/src/taskiq_cancellation/abc/notifier.py index 938618a..eb0718d 100644 --- a/src/taskiq_cancellation/abc/notifier.py +++ b/src/taskiq_cancellation/abc/notifier.py @@ -39,12 +39,11 @@ async def listen_for_cancellation( receives :ref:`CancellationMessage` with same id as *task_id*. This function is used in :func:`cancellable` decorator of :ref:`ModularCancellationBackend`. - Call `started_listening_task_status.started()` when the listener is ready - to receive messages. + Call `started_listening_event.set()` when the listener is ready to receive messages. :param task_id: id of task that will be listened for :type task_id: str - :param started_listening_task_status: - :type started_listening_task_status: anyio.abc.TaskStatus + :param started_listening_event: "listener started listening" confirmation event + :type started_listening_event: StartedListeningEvent """ pass diff --git a/src/taskiq_cancellation/abc/started_listening_event.py b/src/taskiq_cancellation/abc/started_listening_event.py index 28a51eb..4e43acf 100644 --- a/src/taskiq_cancellation/abc/started_listening_event.py +++ b/src/taskiq_cancellation/abc/started_listening_event.py @@ -2,10 +2,21 @@ class StartedListeningEvent(abc.ABC): + """ + A confirmation event for listeners to mark that they started listening to messages. API is + similar to :ref:`asyncio.Event`. + + This is needed for different cancellation types: + - Level cancellation uses :ref:`anyio.abc.TaskStatus` + - Edge cancellation uses :ref:`asyncio.Event` + """ + @abc.abstractmethod async def set(self): + """Sets the event""" pass @abc.abstractmethod async def wait(self): + """Waits for the event to be set""" pass diff --git a/src/taskiq_cancellation/backends/in_memory.py b/src/taskiq_cancellation/backends/in_memory.py index 6c8f172..795ac6e 100644 --- a/src/taskiq_cancellation/backends/in_memory.py +++ b/src/taskiq_cancellation/backends/in_memory.py @@ -5,6 +5,12 @@ class InMemoryCancellationBackend(ModularCancellationBackend): + """ + Cancellation backend that stores state and notifications in memory + + Useful for testing purposes + """ + def __init__(self, **kwargs): super().__init__( state_holder=InMemoryCancellationStateHolder(**kwargs), diff --git a/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py index 0dd0d89..9714e67 100644 --- a/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py +++ b/src/taskiq_cancellation/cancellation_handlers/cancellation_type.py @@ -2,5 +2,7 @@ class CancellationType(str, enum.Enum): + """Type of cancellation used by the backend""" + EDGE = "edge" LEVEL = "level" diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py index 441d66e..4bf66b5 100644 --- a/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py +++ b/src/taskiq_cancellation/cancellation_handlers/edge_3_11.py @@ -32,6 +32,16 @@ class EdgeCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses edge cancellation provided by asyncio. That means :ref:`asyncio.CancelledError` is + raised only once for the task. + Docs: https://docs.python.org/3/library/asyncio-task.html#task-cancellation + + Currently is supported in Python 3.11+ due to using :ref:`asyncio.TaskGroup`. + """ + class ListeningEvent(StartedListeningEvent): def __init__(self) -> None: self.event = asyncio.Event() @@ -54,8 +64,7 @@ def __init__( self.task_id = task_id async def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> Result: - result: Result = None - # type: ignore + result: Result = None # type: ignore listener_exception: Exception | None = None task_exception: Exception | None = None diff --git a/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py index 7184105..03d06fd 100644 --- a/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py +++ b/src/taskiq_cancellation/cancellation_handlers/edge_non_supported.py @@ -15,6 +15,13 @@ class EdgeCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses edge cancellation provided by asyncio. Currently is supported in Python 3.11+ due + to using :ref:`asyncio.TaskGroup`. + """ + def __init__( self, backend: "CancellationBackend", diff --git a/src/taskiq_cancellation/cancellation_handlers/level.py b/src/taskiq_cancellation/cancellation_handlers/level.py index 55f740b..2cfe71a 100644 --- a/src/taskiq_cancellation/cancellation_handlers/level.py +++ b/src/taskiq_cancellation/cancellation_handlers/level.py @@ -22,6 +22,14 @@ class LevelCancellationHandler(Generic[Params, Result]): + """ + Wrapper around a task function that handles cancellation + + Uses level cancellation provided by anyio. That means cancellation exception is raised + on every await in the coroutine. + Docs: https://anyio.readthedocs.io/en/stable/cancellation.html#differences-between-asyncio-and-anyio-cancellation-semantics + """ + class ListeningEvent(StartedListeningEvent): def __init__(self, task_status: TaskStatus) -> None: self.task_status = task_status diff --git a/src/taskiq_cancellation/notifiers/aiopika.py b/src/taskiq_cancellation/notifiers/aiopika.py index 205ede2..945a445 100644 --- a/src/taskiq_cancellation/notifiers/aiopika.py +++ b/src/taskiq_cancellation/notifiers/aiopika.py @@ -10,6 +10,8 @@ class AioPikaNotifier(QueueCancellationNotifier): + """Notifier for RabbitMQ using aio-pika""" + EXCHANGE_NAME = "__taskiq_cancellation" def __init__(self, url: str, **kwargs): diff --git a/src/taskiq_cancellation/notifiers/queue.py b/src/taskiq_cancellation/notifiers/queue.py index f18c083..16ea174 100644 --- a/src/taskiq_cancellation/notifiers/queue.py +++ b/src/taskiq_cancellation/notifiers/queue.py @@ -28,6 +28,7 @@ def __init__(self, **kwargs) -> None: async def shutdown(self) -> None: if self.listener_task is not None: self.listener_task.cancel() + await asyncio.wait([self.listener_task]) async def listen_for_cancellation( self, task_id: str, started_listening_event: StartedListeningEvent diff --git a/src/taskiq_cancellation/notifiers/redis.py b/src/taskiq_cancellation/notifiers/redis.py index 51b29bc..5d7906e 100644 --- a/src/taskiq_cancellation/notifiers/redis.py +++ b/src/taskiq_cancellation/notifiers/redis.py @@ -10,6 +10,8 @@ class PubSubCancellationNotifier(QueueCancellationNotifier): + """Cancellation notifier using Redis pub/sub""" + CHANNEL_NAME = "__taskiq_cancellation_notifications" def __init__(self, url: str, **kwargs) -> None: From d6adb15f39199fefa503305d770d5506c9afb36d Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Tue, 11 Nov 2025 15:04:21 +0300 Subject: [PATCH 17/18] test: intergration tests for redis and aiopika --- .github/workflows/run_tests.yaml | 38 +++++++- docker-compose-tests.yml | 20 +++++ pyproject.toml | 2 +- tests/integration/__init__.py | 0 tests/integration/aiopika/__init__.py | 0 tests/integration/aiopika/conftest.py | 14 +++ tests/integration/aiopika/test_notifier.py | 15 ++++ tests/integration/common/cancellations.py | 94 ++++++++++++++++++++ tests/integration/redis/__init__.py | 0 tests/integration/redis/conftest.py | 14 +++ tests/integration/redis/test_backend.py | 15 ++++ tests/integration/redis/test_pubsub.py | 15 ++++ tests/integration/redis/test_state_holder.py | 15 ++++ tests/unit/__init__.py | 0 tests/{ => unit}/test_backend.py | 0 tests/{ => unit}/test_cancellation.py | 0 16 files changed, 237 insertions(+), 5 deletions(-) create mode 100644 docker-compose-tests.yml create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/aiopika/__init__.py create mode 100644 tests/integration/aiopika/conftest.py create mode 100644 tests/integration/aiopika/test_notifier.py create mode 100644 tests/integration/common/cancellations.py create mode 100644 tests/integration/redis/__init__.py create mode 100644 tests/integration/redis/conftest.py create mode 100644 tests/integration/redis/test_backend.py create mode 100644 tests/integration/redis/test_pubsub.py create mode 100644 tests/integration/redis/test_state_holder.py create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/test_backend.py (100%) rename tests/{ => unit}/test_cancellation.py (100%) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index fdd04d6..ca80840 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -2,10 +2,10 @@ name: Testing on: pull_request: - branches: [develop] + branches: [develop, main] jobs: - run-tests: + run-unit-tests: strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] @@ -13,6 +13,7 @@ jobs: fail-fast: false runs-on: ${{ matrix.os }} + steps: - name: Checkout uses: actions/checkout@v5 @@ -31,5 +32,34 @@ jobs: - name: Install modules run: uv sync - - name: Run tests for Python ${{ matrix.python-version }} - run: uv run pytest + - name: Run unit tests for Python ${{ matrix.python-version }} + run: uv run pytest tests/unit + + run-integration-tests: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + fail-fast: false + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Setup uv + uses: astral-sh/setup-uv@v7 + + - name: Install modules + run: uv sync --extra redis --extra aiopika + + - name: Setup Docker containers + run: docker compose -f ./docker-compose-tests.yml up --wait + + - name: Run integration tests for Python ${{ matrix.python-version }} + run: uv run pytest tests/integration diff --git a/docker-compose-tests.yml b/docker-compose-tests.yml new file mode 100644 index 0000000..ffa0fbd --- /dev/null +++ b/docker-compose-tests.yml @@ -0,0 +1,20 @@ +services: + redis: + image: redis:latest + ports: + - "6379:6379" + + rabbitmq: + image: rabbitmq:latest + environment: + - RABBITMQ_DEFAULT_USER=guest + - RABBITMQ_DEFAULT_PASSWORD=guest + hostname: localhost + ports: + - "5672:5672" + healthcheck: + test: "rabbitmq-diagnostics check_running -q" + interval: 5s + timeout: 5s + retries: 10 + start_period: 5s \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 677fb5b..1df4f56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ ] [project.optional-dependencies] -redis = ["redis~=3.0"] +redis = ["redis~=6.0"] aiopika = ["aio_pika"] [project.urls] diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/aiopika/__init__.py b/tests/integration/aiopika/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/aiopika/conftest.py b/tests/integration/aiopika/conftest.py new file mode 100644 index 0000000..53e27dc --- /dev/null +++ b/tests/integration/aiopika/conftest.py @@ -0,0 +1,14 @@ +import os +import pytest + +from taskiq import InMemoryBroker + + +@pytest.fixture +def rabbitmq_url(): + return os.environ.get("TEST_RABBITMQ_URL", "amqp://guest:guest@localhost:5672") + + +@pytest.fixture +def broker(): + return InMemoryBroker() diff --git a/tests/integration/aiopika/test_notifier.py b/tests/integration/aiopika/test_notifier.py new file mode 100644 index 0000000..11e6553 --- /dev/null +++ b/tests/integration/aiopika/test_notifier.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.notifiers.aiopika import AioPikaNotifier + +from ..common.cancellations import run_notifier_cancellation_test + + +@pytest.fixture +def notifier(rabbitmq_url): + return AioPikaNotifier(url=rabbitmq_url) + + +@pytest.mark.asyncio +async def test_cancellation(notifier: AioPikaNotifier): + await run_notifier_cancellation_test(notifier) diff --git a/tests/integration/common/cancellations.py b/tests/integration/common/cancellations.py new file mode 100644 index 0000000..e495604 --- /dev/null +++ b/tests/integration/common/cancellations.py @@ -0,0 +1,94 @@ +import sys +import uuid +import pytest +import asyncio + +from taskiq import InMemoryBroker + +from taskiq_cancellation.abc import CancellationBackend, CancellationNotifier, CancellationStateHolder +from taskiq_cancellation.backends.modular import ModularCancellationBackend +from taskiq_cancellation.notifiers.null import NullCancellationNotifier +from taskiq_cancellation.state_holders.null import NullCancellationStateHolder +from taskiq_cancellation.exceptions import TaskCancellationException + +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout + + + +async def run_backend_cancellation_test(backend: CancellationBackend): + broker = InMemoryBroker() + backend = backend.with_broker(broker) + + @broker.task + @backend.cancellable + async def test_task(): + pass + + await broker.startup() + + task = await test_task.kiq() + await backend.cancel(task.task_id) + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +async def run_notifier_cancellation_test(notifier: CancellationNotifier): + broker = InMemoryBroker() + backend = ModularCancellationBackend( + NullCancellationStateHolder(), + notifier + ).with_broker(broker) + + task_started = asyncio.Event() + + @broker.task + @backend.cancellable + async def test_task(): + task_started.set() + await asyncio.sleep(0.2) + + await broker.startup() + + task = await test_task.kiq() + async with timeout(1): + await task_started.wait() + + await backend.cancel(task.task_id) + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() + + +async def run_state_holder_cancellation_test(state_holder: CancellationStateHolder): + broker = InMemoryBroker() + backend = ModularCancellationBackend( + state_holder, + NullCancellationNotifier() + ).with_broker(broker) + + @broker.task + @backend.cancellable + async def test_task(): + # This task is supposed to never start + assert False + + await broker.startup() + + task_id = str(uuid.uuid4()) + await backend.cancel(task_id) + task = await test_task.kicker().with_task_id(task_id).kiq() + + with pytest.raises(TaskCancellationException): + result = await task.wait_result() + result.raise_for_error() + + await broker.shutdown() diff --git a/tests/integration/redis/__init__.py b/tests/integration/redis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/redis/conftest.py b/tests/integration/redis/conftest.py new file mode 100644 index 0000000..1823f42 --- /dev/null +++ b/tests/integration/redis/conftest.py @@ -0,0 +1,14 @@ +import os +import pytest + +from taskiq import InMemoryBroker + + +@pytest.fixture +def redis_url(): + return os.environ.get("TEST_REDIS_URL", "redis://localhost:6379") + + +@pytest.fixture +def broker(): + return InMemoryBroker() diff --git a/tests/integration/redis/test_backend.py b/tests/integration/redis/test_backend.py new file mode 100644 index 0000000..a5bde89 --- /dev/null +++ b/tests/integration/redis/test_backend.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.backends.redis import RedisCancellationBackend + +from ..common.cancellations import run_backend_cancellation_test + + +@pytest.fixture +def redis_backend(redis_url, broker): + return RedisCancellationBackend(url=redis_url).with_broker(broker) + + +@pytest.mark.asyncio +async def test_cancellation(redis_backend: RedisCancellationBackend): + await run_backend_cancellation_test(redis_backend) diff --git a/tests/integration/redis/test_pubsub.py b/tests/integration/redis/test_pubsub.py new file mode 100644 index 0000000..b49720c --- /dev/null +++ b/tests/integration/redis/test_pubsub.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.notifiers.redis import PubSubCancellationNotifier + +from ..common.cancellations import run_notifier_cancellation_test + + +@pytest.fixture +def pubsub_notifier(redis_url): + return PubSubCancellationNotifier(url=redis_url) + + +@pytest.mark.asyncio +async def test_cancellation(pubsub_notifier: PubSubCancellationNotifier): + await run_notifier_cancellation_test(pubsub_notifier) diff --git a/tests/integration/redis/test_state_holder.py b/tests/integration/redis/test_state_holder.py new file mode 100644 index 0000000..f8be37b --- /dev/null +++ b/tests/integration/redis/test_state_holder.py @@ -0,0 +1,15 @@ +import pytest + +from taskiq_cancellation.state_holders.redis import RedisCancellationStateHolder + +from ..common.cancellations import run_state_holder_cancellation_test + + +@pytest.fixture +def state_holder(redis_url): + return RedisCancellationStateHolder(url=redis_url) + + +@pytest.mark.asyncio +async def test_cancellation(state_holder: RedisCancellationStateHolder): + await run_state_holder_cancellation_test(state_holder) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_backend.py b/tests/unit/test_backend.py similarity index 100% rename from tests/test_backend.py rename to tests/unit/test_backend.py diff --git a/tests/test_cancellation.py b/tests/unit/test_cancellation.py similarity index 100% rename from tests/test_cancellation.py rename to tests/unit/test_cancellation.py From 3ba287f1233c151e902dc18e9cc1f1e08b3cc50c Mon Sep 17 00:00:00 2001 From: Alexander Starikov Date: Tue, 11 Nov 2025 15:50:20 +0300 Subject: [PATCH 18/18] feat: better serializer api and redis/aiopika notifier inits --- src/taskiq_cancellation/abc/notifier.py | 22 ++++++++++++++++++-- src/taskiq_cancellation/backends/modular.py | 21 +++++++++++++++++++ src/taskiq_cancellation/notifiers/aiopika.py | 20 ++++++++++++------ src/taskiq_cancellation/notifiers/redis.py | 13 +++++++++--- 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/src/taskiq_cancellation/abc/notifier.py b/src/taskiq_cancellation/abc/notifier.py index eb0718d..60b1d14 100644 --- a/src/taskiq_cancellation/abc/notifier.py +++ b/src/taskiq_cancellation/abc/notifier.py @@ -1,5 +1,11 @@ +import sys import abc +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + from taskiq.abc.serializer import TaskiqSerializer from taskiq.serializers import JSONSerializer @@ -9,8 +15,8 @@ class CancellationNotifier(abc.ABC): """Receives cancellation messages and notifies listeners of these messages""" - def __init__(self, serializer: TaskiqSerializer = JSONSerializer()): - self.serializer = serializer + def __init__(self): + self.serializer = JSONSerializer() async def startup(self) -> None: """Starts up cancellation notifier""" @@ -47,3 +53,15 @@ async def listen_for_cancellation( :type started_listening_event: StartedListeningEvent """ pass + + def with_serializer(self, serializer: TaskiqSerializer) -> Self: + """ + Sets a serializer to be used by the notifier + + :param serializer: serializer for cancellation messages + :type serializer: TaskiqSerializer + :return: self + """ + self.serializer = serializer + + return self diff --git a/src/taskiq_cancellation/backends/modular.py b/src/taskiq_cancellation/backends/modular.py index 2da1ec4..c65c00c 100644 --- a/src/taskiq_cancellation/backends/modular.py +++ b/src/taskiq_cancellation/backends/modular.py @@ -1,3 +1,12 @@ +import sys + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from taskiq.abc.serializer import TaskiqSerializer + from taskiq_cancellation.abc import ( CancellationBackend, CancellationNotifier, @@ -47,3 +56,15 @@ async def shutdown(self) -> None: await super().shutdown() await self.state_holder.shutdown() await self.notifier.shutdown() + + def with_serializer(self, serializer: TaskiqSerializer) -> Self: + """ + Sets a serializer to be used by the notifier + + :param serializer: serializer for cancellation messages + :type serializer: TaskiqSerializer + :return: self + """ + self.notifier = self.notifier.with_serializer(serializer) + + return self diff --git a/src/taskiq_cancellation/notifiers/aiopika.py b/src/taskiq_cancellation/notifiers/aiopika.py index 945a445..2432e56 100644 --- a/src/taskiq_cancellation/notifiers/aiopika.py +++ b/src/taskiq_cancellation/notifiers/aiopika.py @@ -14,15 +14,23 @@ class AioPikaNotifier(QueueCancellationNotifier): EXCHANGE_NAME = "__taskiq_cancellation" - def __init__(self, url: str, **kwargs): - super().__init__(**kwargs) + def __init__(self, url: str, **connection_kwargs): + """ + Creates AioPika notifier + + :param url: url to rabbitmq + :type url: str + :param connection_kwargs: arguments for :ref:`aio_pika.connect_robust` + """ + super().__init__() self.url: str = url + self.connection_kwargs = connection_kwargs async def cancel(self, task_id: str) -> None: timestamp = time.time() - connection = await aio_pika.connect_robust(self.url) + connection = await aio_pika.connect_robust(self.url, **self.connection_kwargs) async with connection: channel = await connection.channel() @@ -43,7 +51,7 @@ async def cancel(self, task_id: str) -> None: ) async def _listen(self, started_listening: asyncio.Event): - connection = await aio_pika.connect_robust(self.url) + connection = await aio_pika.connect_robust(self.url, **self.connection_kwargs) async with connection: channel = await connection.channel() @@ -63,6 +71,6 @@ async def _listen(self, started_listening: asyncio.Event): CancellationMessage, self.serializer.loadb(message.body) ) - for queue in self.queues: - await queue.put(cancellation_message) + for subscriber_queue in self.queues: + await subscriber_queue.put(cancellation_message) await message.ack() diff --git a/src/taskiq_cancellation/notifiers/redis.py b/src/taskiq_cancellation/notifiers/redis.py index 5d7906e..fe7746e 100644 --- a/src/taskiq_cancellation/notifiers/redis.py +++ b/src/taskiq_cancellation/notifiers/redis.py @@ -14,10 +14,17 @@ class PubSubCancellationNotifier(QueueCancellationNotifier): CHANNEL_NAME = "__taskiq_cancellation_notifications" - def __init__(self, url: str, **kwargs) -> None: - super().__init__(**kwargs) + def __init__(self, url: str, **connection_kwargs) -> None: + """ + Creates AioPika notifier - self.connection_pool = redis.BlockingConnectionPool.from_url(url, **kwargs) + :param url: url to redis + :type url: str + :param connection_kwargs: arguments for :ref:`redis.BlockingConnectionPool.from_url` + """ + super().__init__() + + self.connection_pool = redis.BlockingConnectionPool.from_url(url, **connection_kwargs) async def cancel(self, task_id: str) -> None: timestamp = time.time()