diff --git a/stompman/listening_events.py b/stompman/listening_events.py index c025c63..f94e3cf 100644 --- a/stompman/listening_events.py +++ b/stompman/listening_events.py @@ -1,5 +1,6 @@ +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Self from stompman.frames import ( AckFrame, @@ -36,6 +37,24 @@ async def nack(self) -> None: ) ) + async def with_auto_ack( + self, + awaitable: Awaitable[None], + *, + on_suppressed_exception: Callable[[Exception, Self], Any], + supressed_exception_classes: tuple[type[Exception], ...] = (Exception,), + ) -> None: + called_nack = False + try: + await awaitable + except supressed_exception_classes as exception: + await self.nack() + called_nack = True + on_suppressed_exception(exception, self) + finally: + if not called_nack: + await self.ack() + @dataclass class ErrorEvent: diff --git a/tests/test_client.py b/tests/test_client.py index 736e12b..6890dc2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -346,6 +346,69 @@ async def test_ack_nack() -> None: assert_frames_between_lifespan_match(collected_frames, [message_frame, nack_frame, ack_frame]) +def get_mocked_message_event() -> tuple[MessageEvent, mock.AsyncMock, mock.AsyncMock, mock.Mock]: + ack_mock, nack_mock, on_suppressed_exception_mock = mock.AsyncMock(), mock.AsyncMock(), mock.Mock() + + class CustomMessageEvent(MessageEvent): + ack = ack_mock + nack = nack_mock + + return ( + CustomMessageEvent( + _frame=MessageFrame( + headers={"destination": "destination", "message-id": "message-id", "subscription": "subscription"}, + body=b"", + ), + _client=mock.Mock(), + ), + ack_mock, + nack_mock, + on_suppressed_exception_mock, + ) + + +async def test_message_event_with_auto_ack_nack() -> None: + event, ack, nack, on_suppressed_exception = get_mocked_message_event() + exception = RuntimeError() + + async def raises_runtime_error() -> None: # noqa: RUF029 + raise exception + + await event.with_auto_ack( + raises_runtime_error(), + supressed_exception_classes=(Exception,), + on_suppressed_exception=on_suppressed_exception, + ) + + ack.assert_not_called() + nack.assert_called_once_with() + on_suppressed_exception.assert_called_once_with(exception, event) + + +async def test_message_event_with_auto_ack_ack_raises() -> None: + event, ack, nack, on_suppressed_exception = get_mocked_message_event() + + async def func() -> None: # noqa: RUF029 + raise Exception # noqa: TRY002 + + with suppress(Exception): + await event.with_auto_ack( + func(), supressed_exception_classes=(RuntimeError,), on_suppressed_exception=on_suppressed_exception + ) + + ack.assert_called_once_with() + nack.assert_not_called() + on_suppressed_exception.assert_not_called() + + +async def test_message_event_with_auto_ack_ack_ok() -> None: + event, ack, nack, on_suppressed_exception = get_mocked_message_event() + await event.with_auto_ack(mock.AsyncMock()(), on_suppressed_exception=on_suppressed_exception) + ack.assert_called_once_with() + nack.assert_not_called() + on_suppressed_exception.assert_not_called() + + async def test_send_message_and_enter_transaction_ok(monkeypatch: pytest.MonkeyPatch) -> None: body = b"hello" destination = "/queue/test"