Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions fila/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
from fila.async_client import AsyncClient
from fila.client import Client
from fila.errors import (
BatchEnqueueError,
EnqueueError,
FilaError,
MessageNotFoundError,
QueueNotFoundError,
RPCError,
)
from fila.types import BatchEnqueueResult, BatchMode, ConsumeMessage, Linger
from fila.types import AccumulatorMode, ConsumeMessage, EnqueueResult, Linger

__all__ = [
"AccumulatorMode",
"AsyncClient",
"BatchEnqueueError",
"BatchEnqueueResult",
"BatchMode",
"Client",
"ConsumeMessage",
"EnqueueError",
"EnqueueResult",
"FilaError",
"Linger",
"MessageNotFoundError",
Expand Down
116 changes: 65 additions & 51 deletions fila/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
if TYPE_CHECKING:
from collections.abc import AsyncIterator

from fila.client import _proto_msg_to_consume_message
from fila.client import _proto_enqueue_result_to_sdk, _proto_msg_to_consume_message
from fila.errors import (
EnqueueError,
_map_ack_error,
_map_batch_enqueue_error,
_map_consume_error,
_map_enqueue_error,
_map_enqueue_result_error,
_map_nack_error,
)
from fila.types import BatchEnqueueResult, ConsumeMessage
from fila.types import ConsumeMessage, EnqueueResult
from fila.v1 import service_pb2, service_pb2_grpc


Expand Down Expand Up @@ -125,7 +126,7 @@ def _extract_leader_hint(err: grpc.RpcError) -> str | None:
class AsyncClient:
"""Asynchronous client for the Fila message broker.

Wraps the hot-path gRPC operations: enqueue, batch_enqueue, consume, ack,
Wraps the hot-path gRPC operations: enqueue, enqueue_many, consume, ack,
nack.

Usage::
Expand Down Expand Up @@ -255,26 +256,35 @@ async def enqueue(
try:
resp = await self._stub.Enqueue(
service_pb2.EnqueueRequest(
queue=queue,
headers=headers or {},
payload=payload,
messages=[
service_pb2.EnqueueMessage(
queue=queue,
headers=headers or {},
payload=payload,
)
]
)
)
except grpc.RpcError as e:
raise _map_enqueue_error(e) from e
return str(resp.message_id)

async def batch_enqueue(
result = resp.results[0]
which = result.WhichOneof("result")
if which == "message_id":
return str(result.message_id)
raise _map_enqueue_result_error(result.error.code, result.error.message)

async def enqueue_many(
self,
messages: list[tuple[str, dict[str, str] | None, bytes]],
) -> list[BatchEnqueueResult]:
) -> list[EnqueueResult]:
"""Enqueue multiple messages in a single RPC.

Args:
messages: List of (queue, headers, payload) tuples.

Returns:
List of ``BatchEnqueueResult`` objects, one per input message.
List of ``EnqueueResult`` objects, one per input message.
Each result has either a ``message_id`` (success) or ``error``
(per-message failure).

Expand All @@ -283,7 +293,7 @@ async def batch_enqueue(
RPCError: For unexpected gRPC failures.
"""
proto_messages = [
service_pb2.EnqueueRequest(
service_pb2.EnqueueMessage(
queue=q,
headers=h or {},
payload=p,
Expand All @@ -292,26 +302,13 @@ async def batch_enqueue(
]

try:
resp = await self._stub.BatchEnqueue(
service_pb2.BatchEnqueueRequest(messages=proto_messages)
resp = await self._stub.Enqueue(
service_pb2.EnqueueRequest(messages=proto_messages)
)
except grpc.RpcError as e:
raise _map_batch_enqueue_error(e) from e

results: list[BatchEnqueueResult] = []
for r in resp.results:
if r.HasField("success"):
results.append(
BatchEnqueueResult(
message_id=str(r.success.message_id),
error=None,
)
)
else:
results.append(
BatchEnqueueResult(message_id=None, error=r.error)
)
return results
raise _map_enqueue_error(e) from e

return [_proto_enqueue_result_to_sdk(r) for r in resp.results]

async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]:
"""Open a streaming consumer on the specified queue.
Expand Down Expand Up @@ -363,25 +360,12 @@ async def _consume_iter(
self,
stream: Any,
) -> AsyncIterator[ConsumeMessage]:
"""Internal async generator reading from the gRPC stream.

Handles both singular ``message`` field (backward compatible) and
repeated ``messages`` field (batched delivery).
"""
"""Internal async generator reading from the gRPC stream."""
try:
async for resp in stream:
# Check batched messages first (repeated field).
if len(resp.messages) > 0:
for msg in resp.messages:
if msg is not None and msg.ByteSize():
yield _proto_msg_to_consume_message(msg)
continue

# Fall back to singular message field.
msg = resp.message
if msg is None or not msg.ByteSize():
continue # keepalive
yield _proto_msg_to_consume_message(msg)
for msg in resp.messages:
if msg is not None and msg.ByteSize():
yield _proto_msg_to_consume_message(msg)
except grpc.RpcError:
return

Expand All @@ -399,12 +383,26 @@ async def ack(self, queue: str, msg_id: str) -> None:
RPCError: For unexpected gRPC failures.
"""
try:
await self._stub.Ack(
service_pb2.AckRequest(queue=queue, message_id=msg_id)
resp = await self._stub.Ack(
service_pb2.AckRequest(
messages=[service_pb2.AckMessage(queue=queue, message_id=msg_id)]
)
)
except grpc.RpcError as e:
raise _map_ack_error(e) from e

# Check per-message result for errors.
if resp.results:
result = resp.results[0]
which = result.WhichOneof("result")
if which == "error":
from fila.errors import MessageNotFoundError, RPCError as _RPCError

ack_err = result.error
if ack_err.code == service_pb2.ACK_ERROR_CODE_MESSAGE_NOT_FOUND:
raise MessageNotFoundError(f"ack: {ack_err.message}")
raise _RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}")

async def nack(self, queue: str, msg_id: str, error: str) -> None:
"""Negatively acknowledge a message that failed processing.

Expand All @@ -421,10 +419,26 @@ async def nack(self, queue: str, msg_id: str, error: str) -> None:
RPCError: For unexpected gRPC failures.
"""
try:
await self._stub.Nack(
resp = await self._stub.Nack(
service_pb2.NackRequest(
queue=queue, message_id=msg_id, error=error
messages=[
service_pb2.NackMessage(
queue=queue, message_id=msg_id, error=error
)
]
)
)
except grpc.RpcError as e:
raise _map_nack_error(e) from e

# Check per-message result for errors.
if resp.results:
result = resp.results[0]
which = result.WhichOneof("result")
if which == "error":
from fila.errors import MessageNotFoundError, RPCError as _RPCError

nack_err = result.error
if nack_err.code == service_pb2.NACK_ERROR_CODE_MESSAGE_NOT_FOUND:
raise MessageNotFoundError(f"nack: {nack_err.message}")
raise _RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}")
Loading
Loading