Skip to content

Commit 0b66575

Browse files
authored
Add Client.subscribe_with_manual_ack() (#105)
1 parent 57f1885 commit 0b66575

File tree

6 files changed

+128
-39
lines changed

6 files changed

+128
-39
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,20 @@ You can pass custom headers to `client.subscribe()`:
108108
await client.subscribe("DLQ", handle_message_from_dlq, ack="client", headers={"selector": "location = 'Europe'"}, on_suppressed_exception=print)
109109
```
110110

111+
#### Handling ACK/NACKs yourself
112+
113+
If you want to send ACK and NACK frames yourself, you can use `client.subscribe_with_manual_ack()`:
114+
115+
```python
116+
async def handle_message_from_dlq(message_frame: stompman.AckableMessageFrame) -> None:
117+
print(message_frame.body)
118+
await message_frame.ack()
119+
120+
await client.subscribe_with_manual_ack("DLQ", handle_message_from_dlq, ack="client")
121+
```
122+
123+
Note that this way exceptions won't be suppressed automatically.
124+
111125
### Cleaning Up
112126

113127
stompman takes care of cleaning up resources automatically. When you leave the context of async context managers `stompman.Client()`, or `client.begin()`, the necessary frames will be sent to the server.

packages/faststream-stomp/faststream_stomp/subscriber.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(
3939
self.headers = headers
4040
self.on_suppressed_exception = on_suppressed_exception
4141
self.suppressed_exception_classes = suppressed_exception_classes
42-
self._subscription: stompman.Subscription | None = None
42+
self._subscription: stompman.AutoAckSubscription | None = None
4343

4444
super().__init__(
4545
no_ack=self.ack == "auto",

packages/stompman/stompman/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,18 @@
3131
UnsubscribeFrame,
3232
)
3333
from stompman.serde import FrameParser, dump_frame
34-
from stompman.subscription import Subscription
34+
from stompman.subscription import AckableMessageFrame, AutoAckSubscription, ManualAckSubscription
3535
from stompman.transaction import Transaction
3636

3737
__all__ = [
3838
"AbortFrame",
3939
"AckFrame",
4040
"AckMode",
41+
"AckableMessageFrame",
4142
"AnyClientFrame",
4243
"AnyRealServerFrame",
4344
"AnyServerFrame",
45+
"AutoAckSubscription",
4446
"BeginFrame",
4547
"Client",
4648
"CommitFrame",
@@ -57,13 +59,13 @@
5759
"FrameParser",
5860
"Heartbeat",
5961
"HeartbeatFrame",
62+
"ManualAckSubscription",
6063
"MessageFrame",
6164
"NackFrame",
6265
"ReceiptFrame",
6366
"SendFrame",
6467
"StompProtocolConnectionIssue",
6568
"SubscribeFrame",
66-
"Subscription",
6769
"Transaction",
6870
"UnsubscribeFrame",
6971
"UnsupportedProtocolVersion",

packages/stompman/stompman/client.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import inspect
3-
from collections.abc import AsyncGenerator, Awaitable, Callable
3+
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
44
from contextlib import AsyncExitStack, asynccontextmanager
55
from dataclasses import dataclass, field
66
from functools import partial
@@ -21,7 +21,7 @@
2121
ReceiptFrame,
2222
SendFrame,
2323
)
24-
from stompman.subscription import Subscription
24+
from stompman.subscription import AckableMessageFrame, ActiveSubscriptions, AutoAckSubscription, ManualAckSubscription
2525
from stompman.transaction import Transaction
2626

2727

@@ -47,7 +47,7 @@ class Client:
4747
connection_class: type[AbstractConnection] = Connection
4848

4949
_connection_manager: ConnectionManager = field(init=False)
50-
_active_subscriptions: dict[str, "Subscription"] = field(default_factory=dict, init=False)
50+
_active_subscriptions: ActiveSubscriptions = field(default_factory=dict, init=False)
5151
_active_transactions: set[Transaction] = field(default_factory=set, init=False)
5252
_exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False)
5353
_heartbeat_task: asyncio.Task[None] = field(init=False)
@@ -113,7 +113,15 @@ async def _listen_to_frames(self) -> None:
113113
match frame:
114114
case MessageFrame():
115115
if subscription := self._active_subscriptions.get(frame.headers["subscription"]):
116-
task_group.create_task(subscription._run_handler(frame=frame)) # noqa: SLF001
116+
task_group.create_task(
117+
subscription._run_handler(frame=frame) # noqa: SLF001
118+
if isinstance(subscription, AutoAckSubscription)
119+
else subscription.handler(
120+
AckableMessageFrame(
121+
headers=frame.headers, body=frame.body, _subscription=subscription
122+
)
123+
)
124+
)
117125
case ErrorFrame():
118126
if self.on_error_frame:
119127
self.on_error_frame(frame)
@@ -152,8 +160,8 @@ async def subscribe(
152160
headers: dict[str, str] | None = None,
153161
on_suppressed_exception: Callable[[Exception, MessageFrame], Any],
154162
suppressed_exception_classes: tuple[type[Exception], ...] = (Exception,),
155-
) -> "Subscription":
156-
subscription = Subscription(
163+
) -> "AutoAckSubscription":
164+
subscription = AutoAckSubscription(
157165
destination=destination,
158166
handler=handler,
159167
headers=headers,
@@ -165,3 +173,22 @@ async def subscribe(
165173
)
166174
await subscription._subscribe() # noqa: SLF001
167175
return subscription
176+
177+
async def subscribe_with_manual_ack(
178+
self,
179+
destination: str,
180+
handler: Callable[[AckableMessageFrame], Coroutine[Any, Any, Any]],
181+
*,
182+
ack: AckMode = "client-individual",
183+
headers: dict[str, str] | None = None,
184+
) -> "ManualAckSubscription":
185+
subscription = ManualAckSubscription(
186+
destination=destination,
187+
handler=handler,
188+
headers=headers,
189+
ack=ack,
190+
_connection_manager=self._connection_manager,
191+
_active_subscriptions=self._active_subscriptions,
192+
)
193+
await subscription._subscribe() # noqa: SLF001
194+
return subscription

packages/stompman/stompman/subscription.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Awaitable, Callable
1+
from collections.abc import Awaitable, Callable, Coroutine
22
from dataclasses import dataclass, field
33
from typing import Any
44
from uuid import uuid4
@@ -14,62 +14,79 @@
1414
UnsubscribeFrame,
1515
)
1616

17-
ActiveSubscriptions = dict[str, "Subscription"]
17+
ActiveSubscriptions = dict[str, "AutoAckSubscription | ManualAckSubscription"]
1818

1919

2020
@dataclass(kw_only=True, slots=True)
21-
class Subscription:
21+
class BaseSubscription:
2222
id: str = field(default_factory=lambda: _make_subscription_id(), init=False) # noqa: PLW0108
2323
destination: str
2424
headers: dict[str, str] | None
25-
handler: Callable[[MessageFrame], Awaitable[Any]]
2625
ack: AckMode
27-
on_suppressed_exception: Callable[[Exception, MessageFrame], Any]
28-
suppressed_exception_classes: tuple[type[Exception], ...]
2926
_connection_manager: ConnectionManager
3027
_active_subscriptions: ActiveSubscriptions
3128

32-
_should_handle_ack_nack: bool = field(init=False)
33-
34-
def __post_init__(self) -> None:
35-
self._should_handle_ack_nack = self.ack in {"client", "client-individual"}
36-
3729
async def _subscribe(self) -> None:
3830
await self._connection_manager.write_frame_reconnecting(
3931
SubscribeFrame.build(
4032
subscription_id=self.id, destination=self.destination, ack=self.ack, headers=self.headers
4133
)
4234
)
43-
self._active_subscriptions[self.id] = self
35+
self._active_subscriptions[self.id] = self # type: ignore[assignment]
4436

4537
async def unsubscribe(self) -> None:
4638
del self._active_subscriptions[self.id]
4739
await self._connection_manager.maybe_write_frame(UnsubscribeFrame(headers={"id": self.id}))
4840

41+
async def _nack(self, frame: MessageFrame) -> None:
42+
if self.id in self._active_subscriptions and (ack_id := frame.headers.get("ack")):
43+
await self._connection_manager.maybe_write_frame(
44+
NackFrame(headers={"id": ack_id, "subscription": frame.headers["subscription"]})
45+
)
46+
47+
async def _ack(self, frame: MessageFrame) -> None:
48+
if self.id in self._active_subscriptions and (ack_id := frame.headers["ack"]):
49+
await self._connection_manager.maybe_write_frame(
50+
AckFrame(headers={"id": ack_id, "subscription": frame.headers["subscription"]})
51+
)
52+
53+
54+
@dataclass(kw_only=True, slots=True)
55+
class AutoAckSubscription(BaseSubscription):
56+
handler: Callable[[MessageFrame], Awaitable[Any]]
57+
on_suppressed_exception: Callable[[Exception, MessageFrame], Any]
58+
suppressed_exception_classes: tuple[type[Exception], ...]
59+
_should_handle_ack_nack: bool = field(init=False)
60+
61+
def __post_init__(self) -> None:
62+
self._should_handle_ack_nack = self.ack in {"client", "client-individual"}
63+
4964
async def _run_handler(self, *, frame: MessageFrame) -> None:
5065
try:
5166
await self.handler(frame)
5267
except self.suppressed_exception_classes as exception:
53-
if (
54-
self._should_handle_ack_nack
55-
and self.id in self._active_subscriptions
56-
and (ack_id := frame.headers["ack"])
57-
):
58-
await self._connection_manager.maybe_write_frame(
59-
NackFrame(headers={"id": ack_id, "subscription": frame.headers["subscription"]})
60-
)
68+
if self._should_handle_ack_nack:
69+
await self._nack(frame)
6170
self.on_suppressed_exception(exception, frame)
6271
else:
63-
if (
64-
self._should_handle_ack_nack
65-
and self.id in self._active_subscriptions
66-
and (ack_id := frame.headers["ack"])
67-
):
68-
await self._connection_manager.maybe_write_frame(
69-
AckFrame(
70-
headers={"id": ack_id, "subscription": frame.headers["subscription"]},
71-
)
72-
)
72+
if self._should_handle_ack_nack:
73+
await self._ack(frame)
74+
75+
76+
@dataclass(kw_only=True, slots=True)
77+
class ManualAckSubscription(BaseSubscription):
78+
handler: Callable[["AckableMessageFrame"], Coroutine[Any, Any, Any]]
79+
80+
81+
@dataclass(frozen=True, kw_only=True, slots=True)
82+
class AckableMessageFrame(MessageFrame):
83+
_subscription: ManualAckSubscription
84+
85+
async def ack(self) -> None:
86+
await self._subscription._ack(self) # noqa: SLF001
87+
88+
async def nack(self) -> None:
89+
await self._subscription._nack(self) # noqa: SLF001
7390

7491

7592
def _make_subscription_id() -> str:

packages/stompman/test_stompman/test_subscription.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,35 @@ async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, fake
278278
)
279279

280280

281+
async def test_client_listen_manual_ack_nack_ok(monkeypatch: pytest.MonkeyPatch, faker: faker.Faker) -> None:
282+
subscription_id, destination, message_id, ack_id = faker.pystr(), faker.pystr(), faker.pystr(), faker.pystr()
283+
monkeypatch.setattr(stompman.subscription, "_make_subscription_id", mock.Mock(return_value=subscription_id))
284+
285+
message_frame = build_dataclass(
286+
MessageFrame,
287+
headers={"destination": destination, "message-id": message_id, "subscription": subscription_id, "ack": ack_id},
288+
)
289+
connection_class, collected_frames = create_spying_connection(*get_read_frames_with_lifespan([message_frame]))
290+
291+
async def handle_message(message: stompman.subscription.AckableMessageFrame) -> None:
292+
await message.ack()
293+
await message.nack()
294+
295+
async with EnrichedClient(connection_class=connection_class) as client:
296+
subscription = await client.subscribe_with_manual_ack(destination, handle_message)
297+
await asyncio.sleep(0)
298+
await asyncio.sleep(0)
299+
await subscription.unsubscribe()
300+
301+
assert collected_frames == enrich_expected_frames(
302+
SubscribeFrame(headers={"ack": "client-individual", "destination": destination, "id": subscription_id}),
303+
message_frame,
304+
AckFrame(headers={"subscription": subscription_id, "id": ack_id}),
305+
NackFrame(headers={"subscription": subscription_id, "id": ack_id}),
306+
UnsubscribeFrame(headers={"id": subscription_id}),
307+
)
308+
309+
281310
async def test_client_listen_raises_on_aexit(monkeypatch: pytest.MonkeyPatch, faker: faker.Faker) -> None:
282311
monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0))
283312

0 commit comments

Comments
 (0)