diff --git a/fila/__init__.py b/fila/__init__.py index a273836..732fc43 100644 --- a/fila/__init__.py +++ b/fila/__init__.py @@ -1,20 +1,25 @@ -"""Fila — Python client SDK for the Fila message broker.""" +"""Fila -- Python client SDK for the Fila message broker.""" from fila.async_client import AsyncClient from fila.client import Client from fila.errors import ( + BatchEnqueueError, FilaError, MessageNotFoundError, QueueNotFoundError, RPCError, ) -from fila.types import ConsumeMessage +from fila.types import BatchEnqueueResult, BatchMode, ConsumeMessage, Linger __all__ = [ "AsyncClient", + "BatchEnqueueError", + "BatchEnqueueResult", + "BatchMode", "Client", "ConsumeMessage", "FilaError", + "Linger", "MessageNotFoundError", "QueueNotFoundError", "RPCError", diff --git a/fila/async_client.py b/fila/async_client.py index 8bab27b..8e06b1e 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -10,8 +10,15 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator -from fila.errors import _map_ack_error, _map_consume_error, _map_enqueue_error, _map_nack_error -from fila.types import ConsumeMessage +from fila.client import _proto_msg_to_consume_message +from fila.errors import ( + _map_ack_error, + _map_batch_enqueue_error, + _map_consume_error, + _map_enqueue_error, + _map_nack_error, +) +from fila.types import BatchEnqueueResult, ConsumeMessage from fila.v1 import service_pb2, service_pb2_grpc @@ -118,7 +125,8 @@ def _extract_leader_hint(err: grpc.RpcError) -> str | None: class AsyncClient: """Asynchronous client for the Fila message broker. - Wraps the hot-path gRPC operations: enqueue, consume, ack, nack. + Wraps the hot-path gRPC operations: enqueue, batch_enqueue, consume, ack, + nack. Usage:: @@ -256,6 +264,55 @@ async def enqueue( raise _map_enqueue_error(e) from e return str(resp.message_id) + async def batch_enqueue( + self, + messages: list[tuple[str, dict[str, str] | None, bytes]], + ) -> list[BatchEnqueueResult]: + """Enqueue multiple messages in a single RPC. + + Args: + messages: List of (queue, headers, payload) tuples. + + Returns: + List of ``BatchEnqueueResult`` objects, one per input message. + Each result has either a ``message_id`` (success) or ``error`` + (per-message failure). + + Raises: + QueueNotFoundError: If a referenced queue does not exist. + RPCError: For unexpected gRPC failures. + """ + proto_messages = [ + service_pb2.EnqueueRequest( + queue=q, + headers=h or {}, + payload=p, + ) + for q, h, p in messages + ] + + try: + resp = await self._stub.BatchEnqueue( + service_pb2.BatchEnqueueRequest(messages=proto_messages) + ) + except grpc.RpcError as e: + raise _map_batch_enqueue_error(e) from e + + results: list[BatchEnqueueResult] = [] + for r in resp.results: + if r.HasField("success"): + results.append( + BatchEnqueueResult( + message_id=str(r.success.message_id), + error=None, + ) + ) + else: + results.append( + BatchEnqueueResult(message_id=None, error=r.error) + ) + return results + async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: """Open a streaming consumer on the specified queue. @@ -306,22 +363,25 @@ async def _consume_iter( self, stream: Any, ) -> AsyncIterator[ConsumeMessage]: - """Internal async generator reading from the gRPC stream.""" + """Internal async generator reading from the gRPC stream. + + Handles both singular ``message`` field (backward compatible) and + repeated ``messages`` field (batched delivery). + """ try: async for resp in stream: + # Check batched messages first (repeated field). 
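+                # A response carries the repeated field, the singular field,
+                # or neither (a keepalive); zero-byte placeholder messages
+                # are skipped rather than surfaced to the caller.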
+ if len(resp.messages) > 0: + for msg in resp.messages: + if msg is not None and msg.ByteSize(): + yield _proto_msg_to_consume_message(msg) + continue + + # Fall back to singular message field. msg = resp.message if msg is None or not msg.ByteSize(): continue # keepalive - metadata = msg.metadata - cm = ConsumeMessage( - id=msg.id, - headers=dict(msg.headers), - payload=bytes(msg.payload), - fairness_key=metadata.fairness_key if metadata else "", - attempt_count=metadata.attempt_count if metadata else 0, - queue=metadata.queue_id if metadata else "", - ) - yield cm + yield _proto_msg_to_consume_message(msg) except grpc.RpcError: return diff --git a/fila/batcher.py b/fila/batcher.py new file mode 100644 index 0000000..c57964a --- /dev/null +++ b/fila/batcher.py @@ -0,0 +1,267 @@ +"""Background batcher for opportunistic and linger-based enqueue batching.""" + +from __future__ import annotations + +import queue +import threading +from concurrent.futures import Future, ThreadPoolExecutor +from typing import TYPE_CHECKING, Any + +import grpc + +from fila.errors import BatchEnqueueError, _map_enqueue_error +from fila.types import BatchEnqueueResult +from fila.v1 import service_pb2 + +if TYPE_CHECKING: + from fila.v1 import service_pb2_grpc + + +# Sentinel that signals the batcher thread to stop. +_STOP = object() + +# Maximum batch size when none is configured. +_DEFAULT_MAX_BATCH_SIZE = 1000 + + +class _EnqueueRequest: + """Internal envelope pairing a proto request with its result future.""" + + __slots__ = ("proto", "future") + + def __init__( + self, + proto: service_pb2.EnqueueRequest, + future: Future[str], + ) -> None: + self.proto = proto + self.future = future + + +def _msg_to_consume_result( + proto_result: Any, +) -> BatchEnqueueResult: + """Convert a proto ``BatchEnqueueResult`` to the SDK type.""" + if proto_result.HasField("success"): + return BatchEnqueueResult( + message_id=proto_result.success.message_id, + error=None, + ) + return BatchEnqueueResult( + message_id=None, + error=proto_result.error, + ) + + +def _flush_single( + stub: service_pb2_grpc.FilaServiceStub, + req: _EnqueueRequest, +) -> None: + """Send a single message via the singular Enqueue RPC. + + This preserves the specific error types (QueueNotFoundError, etc.) + that callers of ``enqueue()`` expect. + """ + try: + resp = stub.Enqueue(req.proto) + req.future.set_result(str(resp.message_id)) + except grpc.RpcError as e: + req.future.set_exception(_map_enqueue_error(e)) + except Exception as e: + req.future.set_exception(e) + + +def _flush_batch( + stub: service_pb2_grpc.FilaServiceStub, + batch: list[_EnqueueRequest], +) -> None: + """Send a batch of messages via the BatchEnqueue RPC. + + On RPC-level failure, every future in the batch receives a + ``BatchEnqueueError``. On success, each future gets either its + message ID or a per-message error string wrapped in a + ``BatchEnqueueError``. + """ + try: + resp = stub.BatchEnqueue( + service_pb2.BatchEnqueueRequest( + messages=[r.proto for r in batch], + ) + ) + except grpc.RpcError as e: + err = BatchEnqueueError(f"batch enqueue rpc failed: {e.details()}") + for r in batch: + r.future.set_exception(err) + return + except Exception as e: + for r in batch: + r.future.set_exception(e) + return + + # Pair each result with its request future. 
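+    # Pairing is positional: the broker is expected to return one result
+    # per request, in request order. Extra results (unexpected) are
+    # dropped by the bounds check; if fewer come back, the leftover
+    # futures are never resolved.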
+ for i, result in enumerate(resp.results): + if i >= len(batch): + break + req = batch[i] + if result.HasField("success"): + req.future.set_result(str(result.success.message_id)) + else: + req.future.set_exception( + BatchEnqueueError(f"enqueue failed: {result.error}") + ) + + +class AutoBatcher: + """Opportunistic batcher: drains a queue and flushes in batches. + + A background daemon thread blocks on the first message, then non-blocking + drains any additional messages that arrived during processing and flushes + them as a single batch via a thread pool executor. + """ + + def __init__( + self, + stub: service_pb2_grpc.FilaServiceStub, + max_batch_size: int = _DEFAULT_MAX_BATCH_SIZE, + max_workers: int = 4, + ) -> None: + self._stub = stub + self._max_batch_size = max_batch_size + self._queue: queue.Queue[_EnqueueRequest | object] = queue.Queue() + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def submit(self, proto: service_pb2.EnqueueRequest) -> Future[str]: + """Submit a message for batched enqueue. Returns a Future for the message ID.""" + fut: Future[str] = Future() + self._queue.put(_EnqueueRequest(proto, fut)) + return fut + + def close(self, timeout: float | None = 30.0) -> None: + """Drain pending messages and shut down the batcher. + + Blocks until all pending messages have been flushed or *timeout* + seconds have elapsed. + """ + self._queue.put(_STOP) + self._thread.join(timeout=timeout) + self._executor.shutdown(wait=True) + + def update_stub(self, stub: service_pb2_grpc.FilaServiceStub) -> None: + """Update the gRPC stub (e.g. after leader-hint reconnect).""" + self._stub = stub + + def _run(self) -> None: + """Background loop: block for first item, drain rest, flush.""" + while True: + # Block until at least one item arrives. + first = self._queue.get() + if first is _STOP: + return + + assert isinstance(first, _EnqueueRequest) + batch: list[_EnqueueRequest] = [first] + + # Non-blocking drain of any additional queued messages. + while len(batch) < self._max_batch_size: + try: + item = self._queue.get_nowait() + except queue.Empty: + break + if item is _STOP: + # Flush what we have, then stop. + self._flush(batch) + return + assert isinstance(item, _EnqueueRequest) + batch.append(item) + + self._flush(batch) + + def _flush(self, batch: list[_EnqueueRequest]) -> None: + """Dispatch a batch to the executor for concurrent RPC.""" + if len(batch) == 1: + # Single-item optimization: use singular Enqueue RPC. + self._executor.submit(_flush_single, self._stub, batch[0]) + else: + self._executor.submit(_flush_batch, self._stub, batch) + + +class LingerBatcher: + """Timer-based batcher: holds messages for up to linger_ms or batch_size. + + A background daemon thread accumulates messages and flushes when either + the batch reaches ``batch_size`` or ``linger_ms`` milliseconds have + elapsed since the first message in the current batch arrived. 
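+
+    A minimal sketch of standalone use (normally ``Client`` constructs this
+    when given ``batch_mode=Linger(...)``; ``stub`` is assumed to be an
+    already-connected ``FilaServiceStub``)::
+
+        batcher = LingerBatcher(stub, linger_ms=10, batch_size=100)
+        fut = batcher.submit(
+            service_pb2.EnqueueRequest(queue="my-queue", payload=b"hello")
+        )
+        message_id = fut.result()  # blocks until the batch flushes
+        batcher.close()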
+ """ + + def __init__( + self, + stub: service_pb2_grpc.FilaServiceStub, + linger_ms: float, + batch_size: int, + max_workers: int = 4, + ) -> None: + self._stub = stub + self._linger_s = linger_ms / 1000.0 + self._batch_size = batch_size + self._queue: queue.Queue[_EnqueueRequest | object] = queue.Queue() + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def submit(self, proto: service_pb2.EnqueueRequest) -> Future[str]: + """Submit a message for batched enqueue. Returns a Future for the message ID.""" + fut: Future[str] = Future() + self._queue.put(_EnqueueRequest(proto, fut)) + return fut + + def close(self, timeout: float | None = 30.0) -> None: + """Drain pending messages and shut down the batcher.""" + self._queue.put(_STOP) + self._thread.join(timeout=timeout) + self._executor.shutdown(wait=True) + + def update_stub(self, stub: service_pb2_grpc.FilaServiceStub) -> None: + """Update the gRPC stub (e.g. after leader-hint reconnect).""" + self._stub = stub + + def _run(self) -> None: + """Background loop: accumulate up to batch_size or linger timeout.""" + import time + + while True: + # Block until at least one item arrives. + first = self._queue.get() + if first is _STOP: + return + + assert isinstance(first, _EnqueueRequest) + batch: list[_EnqueueRequest] = [first] + + # Track wall-clock deadline from when first message arrived. + deadline = time.monotonic() + self._linger_s + + # Accumulate more items until batch_size or linger timeout. + while len(batch) < self._batch_size: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + try: + item = self._queue.get(timeout=remaining) + except queue.Empty: + break + if item is _STOP: + self._flush(batch) + return + assert isinstance(item, _EnqueueRequest) + batch.append(item) + + self._flush(batch) + + def _flush(self, batch: list[_EnqueueRequest]) -> None: + """Dispatch a batch to the executor for concurrent RPC.""" + if len(batch) == 1: + self._executor.submit(_flush_single, self._stub, batch[0]) + else: + self._executor.submit(_flush_batch, self._stub, batch) diff --git a/fila/client.py b/fila/client.py index 891907a..0d7e49a 100644 --- a/fila/client.py +++ b/fila/client.py @@ -6,8 +6,15 @@ import grpc -from fila.errors import _map_ack_error, _map_consume_error, _map_enqueue_error, _map_nack_error -from fila.types import ConsumeMessage +from fila.batcher import AutoBatcher, LingerBatcher +from fila.errors import ( + _map_ack_error, + _map_batch_enqueue_error, + _map_consume_error, + _map_enqueue_error, + _map_nack_error, +) +from fila.types import BatchEnqueueResult, BatchMode, ConsumeMessage, Linger from fila.v1 import service_pb2, service_pb2_grpc if TYPE_CHECKING: @@ -33,6 +40,19 @@ def _extract_leader_hint(err: grpc.RpcError) -> str | None: return None +def _proto_msg_to_consume_message(msg: Any) -> ConsumeMessage: + """Convert a protobuf Message to a ConsumeMessage.""" + metadata = msg.metadata + return ConsumeMessage( + id=msg.id, + headers=dict(msg.headers), + payload=bytes(msg.payload), + fairness_key=metadata.fairness_key if metadata else "", + attempt_count=metadata.attempt_count if metadata else 0, + queue=metadata.queue_id if metadata else "", + ) + + class _ClientCallDetails( grpc.ClientCallDetails, # type: ignore[misc] ): @@ -102,7 +122,8 @@ def intercept_unary_stream( class Client: """Synchronous client for the Fila message broker. - Wraps the hot-path gRPC operations: enqueue, consume, ack, nack. 
+ Wraps the hot-path gRPC operations: enqueue, batch_enqueue, consume, ack, + nack. Usage:: @@ -117,6 +138,17 @@ class Client: with Client("localhost:5555") as client: client.enqueue("my-queue", None, b"hello") + Batch modes:: + + # AUTO (default): opportunistic batching via background thread + client = Client("localhost:5555") + + # DISABLED: each enqueue() is a direct RPC + client = Client("localhost:5555", batch_mode=BatchMode.DISABLED) + + # LINGER: timer-based forced batching + client = Client("localhost:5555", batch_mode=Linger(linger_ms=10, batch_size=100)) + TLS (system trust store):: client = Client("localhost:5555", tls=True) @@ -147,6 +179,8 @@ def __init__( client_cert: bytes | None = None, client_key: bytes | None = None, api_key: str | None = None, + batch_mode: BatchMode | Linger = BatchMode.AUTO, + max_batch_size: int = 1000, ) -> None: """Connect to a Fila broker at the given address. @@ -161,6 +195,10 @@ def __init__( client_key: PEM-encoded client private key for mutual TLS (optional). api_key: API key for authentication. When set, every RPC includes an ``authorization: Bearer `` metadata header. + batch_mode: Controls how ``enqueue()`` routes messages. Defaults to + ``BatchMode.AUTO`` (opportunistic batching). + max_batch_size: Maximum number of messages per batch when using + ``BatchMode.AUTO``. Defaults to 1000. """ self._tls = tls self._ca_cert = ca_cert @@ -177,6 +215,21 @@ def __init__( self._channel = self._make_channel(addr) self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + # Set up the batcher based on the chosen mode. + self._batcher: AutoBatcher | LingerBatcher | None = None + if isinstance(batch_mode, Linger): + self._batcher = LingerBatcher( + self._stub, + linger_ms=batch_mode.linger_ms, + batch_size=batch_mode.batch_size, + ) + elif batch_mode is BatchMode.AUTO: + self._batcher = AutoBatcher( + self._stub, + max_batch_size=max_batch_size, + ) + # BatchMode.DISABLED: self._batcher stays None + def _make_channel(self, addr: str) -> grpc.Channel: """Create a gRPC channel to the given address using stored credentials.""" use_tls = self._tls or self._ca_cert is not None @@ -198,7 +251,9 @@ def _make_channel(self, addr: str) -> grpc.Channel: return channel def close(self) -> None: - """Close the underlying gRPC channel.""" + """Drain pending batched messages and close the underlying gRPC channel.""" + if self._batcher is not None: + self._batcher.close() self._channel.close() def __enter__(self) -> Client: @@ -215,6 +270,13 @@ def enqueue( ) -> str: """Enqueue a message to the specified queue. + When a batcher is active (``BatchMode.AUTO`` or ``Linger``), the + message is submitted to the background batcher and this call blocks + until the batch is flushed and the result is available. + + When batching is disabled (``BatchMode.DISABLED``), this call makes + a direct synchronous RPC. + Args: queue: Target queue name. headers: Optional message headers. @@ -224,21 +286,79 @@ def enqueue( Broker-assigned message ID (UUIDv7). Raises: - QueueNotFoundError: If the queue does not exist. + QueueNotFoundError: If the queue does not exist (DISABLED mode). + BatchEnqueueError: If the batch RPC fails (AUTO/LINGER mode). RPCError: For unexpected gRPC failures. """ + proto = service_pb2.EnqueueRequest( + queue=queue, + headers=headers or {}, + payload=payload, + ) + + if self._batcher is not None: + future = self._batcher.submit(proto) + return future.result() + + # Direct RPC (DISABLED mode). 
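+        # Failures on this path keep their specific exception types
+        # (QueueNotFoundError, etc.) since the RPC is made inline.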
try: - resp = self._stub.Enqueue( - service_pb2.EnqueueRequest( - queue=queue, - headers=headers or {}, - payload=payload, - ) - ) + resp = self._stub.Enqueue(proto) except grpc.RpcError as e: raise _map_enqueue_error(e) from e return str(resp.message_id) + def batch_enqueue( + self, + messages: list[tuple[str, dict[str, str] | None, bytes]], + ) -> list[BatchEnqueueResult]: + """Enqueue multiple messages in a single RPC. + + This is an explicit batch operation that always uses the BatchEnqueue + RPC regardless of the batch_mode setting. + + Args: + messages: List of (queue, headers, payload) tuples. + + Returns: + List of ``BatchEnqueueResult`` objects, one per input message. + Each result has either a ``message_id`` (success) or ``error`` + (per-message failure). + + Raises: + QueueNotFoundError: If a referenced queue does not exist. + RPCError: For unexpected gRPC failures. + """ + proto_messages = [ + service_pb2.EnqueueRequest( + queue=q, + headers=h or {}, + payload=p, + ) + for q, h, p in messages + ] + + try: + resp = self._stub.BatchEnqueue( + service_pb2.BatchEnqueueRequest(messages=proto_messages) + ) + except grpc.RpcError as e: + raise _map_batch_enqueue_error(e) from e + + results: list[BatchEnqueueResult] = [] + for r in resp.results: + if r.HasField("success"): + results.append( + BatchEnqueueResult( + message_id=str(r.success.message_id), + error=None, + ) + ) + else: + results.append( + BatchEnqueueResult(message_id=None, error=r.error) + ) + return results + def consume(self, queue: str) -> Iterator[ConsumeMessage]: """Open a streaming consumer on the specified queue. @@ -278,6 +398,8 @@ def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: self._channel.close() self._channel = self._make_channel(leader_addr) self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + if self._batcher is not None: + self._batcher.update_stub(self._stub) try: return self._stub.Consume( service_pb2.ConsumeRequest(queue=queue) @@ -289,22 +411,25 @@ def _consume_iter( self, stream: Any, ) -> Iterator[ConsumeMessage]: - """Internal generator reading from the gRPC stream.""" + """Internal generator reading from the gRPC stream. + + Handles both singular ``message`` field (backward compatible) and + repeated ``messages`` field (batched delivery). + """ try: for resp in stream: + # Check batched messages first (repeated field). + if len(resp.messages) > 0: + for msg in resp.messages: + if msg is not None and msg.ByteSize(): + yield _proto_msg_to_consume_message(msg) + continue + + # Fall back to singular message field. msg = resp.message if msg is None or not msg.ByteSize(): continue # keepalive - metadata = msg.metadata - cm = ConsumeMessage( - id=msg.id, - headers=dict(msg.headers), - payload=bytes(msg.payload), - fairness_key=metadata.fairness_key if metadata else "", - attempt_count=metadata.attempt_count if metadata else 0, - queue=metadata.queue_id if metadata else "", - ) - yield cm + yield _proto_msg_to_consume_message(msg) except grpc.RpcError: return diff --git a/fila/errors.py b/fila/errors.py index 346c1c6..40e76ee 100644 --- a/fila/errors.py +++ b/fila/errors.py @@ -26,6 +26,15 @@ def __init__(self, code: grpc.StatusCode, message: str) -> None: super().__init__(f"rpc error (code = {code.name}): {message}") +class BatchEnqueueError(FilaError): + """Raised when a batched enqueue fails at the RPC level. + + Individual per-message failures are reported via ``BatchEnqueueResult.error`` + and do not raise this exception. 
This is raised when the whole batch
+    RPC fails (e.g., network error, server unavailable), and also from
+    ``enqueue()`` when the background batcher reports a per-message
+    failure for a multi-message flush.
+    """
+
+
 def _map_enqueue_error(err: grpc.RpcError) -> FilaError:
     """Map a gRPC error from an enqueue call to a Fila exception."""
     code = err.code()
@@ -56,3 +65,11 @@ def _map_nack_error(err: grpc.RpcError) -> FilaError:
     if code == grpc.StatusCode.NOT_FOUND:
         return MessageNotFoundError(f"nack: {err.details()}")
     return RPCError(code, err.details() or "")
+
+
+def _map_batch_enqueue_error(err: grpc.RpcError) -> FilaError:
+    """Map a gRPC error from a batch enqueue call to a Fila exception."""
+    code = err.code()
+    if code == grpc.StatusCode.NOT_FOUND:
+        return QueueNotFoundError(f"batch_enqueue: {err.details()}")
+    return RPCError(code, err.details() or "")
diff --git a/fila/types.py b/fila/types.py
index 2474228..54ab034 100644
--- a/fila/types.py
+++ b/fila/types.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from enum import Enum, auto
 
 
 @dataclass(frozen=True)
@@ -15,3 +16,48 @@ class ConsumeMessage:
     fairness_key: str
     attempt_count: int
     queue: str
+
+
+@dataclass(frozen=True)
+class BatchEnqueueResult:
+    """Result for a single message within a batch enqueue operation.
+
+    Exactly one of ``message_id`` or ``error`` is set.
+    """
+
+    message_id: str | None
+    error: str | None
+
+    @property
+    def is_success(self) -> bool:
+        """Return True if this message was enqueued successfully."""
+        return self.message_id is not None
+
+
+class BatchMode(Enum):
+    """Controls how ``enqueue()`` routes messages to the broker.
+
+    - ``AUTO``: Opportunistic batching via a background thread. At low load
+      messages are sent individually; at high load they cluster into batches.
+      This is the default.
+    - ``DISABLED``: No batching. Each ``enqueue()`` call is a direct RPC.
+    """
+
+    AUTO = auto()
+    DISABLED = auto()
+
+
+@dataclass(frozen=True)
+class Linger:
+    """Timer-based forced batching mode.
+
+    Messages are held for up to ``linger_ms`` milliseconds or until
+    ``batch_size`` messages accumulate, whichever comes first.
+
+    Args:
+        linger_ms: Maximum time to hold a message before flushing (milliseconds).
+        batch_size: Maximum number of messages per batch.
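+
+    Example (values are illustrative, not tuning advice)::
+
+        client = Client(
+            "localhost:5555",
+            batch_mode=Linger(linger_ms=10, batch_size=100),
+        )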
+ """ + + linger_ms: float + batch_size: int diff --git a/fila/v1/messages_pb2_grpc.py b/fila/v1/messages_pb2_grpc.py index fa0dc71..d27d27c 100644 --- a/fila/v1/messages_pb2_grpc.py +++ b/fila/v1/messages_pb2_grpc.py @@ -4,7 +4,7 @@ import warnings -GRPC_GENERATED_VERSION = '1.78.0' +GRPC_GENERATED_VERSION = '1.78.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False diff --git a/fila/v1/service_pb2.py b/fila/v1/service_pb2.py index 11ad1f0..7f04078 100644 --- a/fila/v1/service_pb2.py +++ b/fila/v1/service_pb2.py @@ -25,7 +25,7 @@ from fila.v1 import messages_pb2 as fila_dot_v1_dot_messages__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66ila/v1/service.proto\x12\x07\x66ila.v1\x1a\x16\x66ila/v1/messages.proto\"\x97\x01\n\x0e\x45nqueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x35\n\x07headers\x18\x02 \x03(\x0b\x32$.fila.v1.EnqueueRequest.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"%\n\x0f\x45nqueueResponse\x12\x12\n\nmessage_id\x18\x01 \x01(\t\"\x1f\n\x0e\x43onsumeRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"4\n\x0f\x43onsumeResponse\x12!\n\x07message\x18\x01 \x01(\x0b\x32\x10.fila.v1.Message\"/\n\nAckRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\"\r\n\x0b\x41\x63kResponse\"?\n\x0bNackRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\"\x0e\n\x0cNackResponse2\xf2\x01\n\x0b\x46ilaService\x12<\n\x07\x45nqueue\x12\x17.fila.v1.EnqueueRequest\x1a\x18.fila.v1.EnqueueResponse\x12>\n\x07\x43onsume\x12\x17.fila.v1.ConsumeRequest\x1a\x18.fila.v1.ConsumeResponse0\x01\x12\x30\n\x03\x41\x63k\x12\x13.fila.v1.AckRequest\x1a\x14.fila.v1.AckResponse\x12\x33\n\x04Nack\x12\x14.fila.v1.NackRequest\x1a\x15.fila.v1.NackResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66ila/v1/service.proto\x12\x07\x66ila.v1\x1a\x16\x66ila/v1/messages.proto\"\x97\x01\n\x0e\x45nqueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x35\n\x07headers\x18\x02 \x03(\x0b\x32$.fila.v1.EnqueueRequest.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"%\n\x0f\x45nqueueResponse\x12\x12\n\nmessage_id\x18\x01 \x01(\t\"\x1f\n\x0e\x43onsumeRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"X\n\x0f\x43onsumeResponse\x12!\n\x07message\x18\x01 \x01(\x0b\x32\x10.fila.v1.Message\x12\"\n\x08messages\x18\x02 \x03(\x0b\x32\x10.fila.v1.Message\"/\n\nAckRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\"\r\n\x0b\x41\x63kResponse\"?\n\x0bNackRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\"\x0e\n\x0cNackResponse\"@\n\x13\x42\x61tchEnqueueRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.fila.v1.EnqueueRequest\"D\n\x14\x42\x61tchEnqueueResponse\x12,\n\x07results\x18\x01 \x03(\x0b\x32\x1b.fila.v1.BatchEnqueueResult\"\\\n\x12\x42\x61tchEnqueueResult\x12+\n\x07success\x18\x01 \x01(\x0b\x32\x18.fila.v1.EnqueueResponseH\x00\x12\x0f\n\x05\x65rror\x18\x02 
\x01(\tH\x00\x42\x08\n\x06result2\xbf\x02\n\x0b\x46ilaService\x12<\n\x07\x45nqueue\x12\x17.fila.v1.EnqueueRequest\x1a\x18.fila.v1.EnqueueResponse\x12K\n\x0c\x42\x61tchEnqueue\x12\x1c.fila.v1.BatchEnqueueRequest\x1a\x1d.fila.v1.BatchEnqueueResponse\x12>\n\x07\x43onsume\x12\x17.fila.v1.ConsumeRequest\x1a\x18.fila.v1.ConsumeResponse0\x01\x12\x30\n\x03\x41\x63k\x12\x13.fila.v1.AckRequest\x1a\x14.fila.v1.AckResponse\x12\x33\n\x04Nack\x12\x14.fila.v1.NackRequest\x1a\x15.fila.v1.NackResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -43,15 +43,21 @@ _globals['_CONSUMEREQUEST']._serialized_start=251 _globals['_CONSUMEREQUEST']._serialized_end=282 _globals['_CONSUMERESPONSE']._serialized_start=284 - _globals['_CONSUMERESPONSE']._serialized_end=336 - _globals['_ACKREQUEST']._serialized_start=338 - _globals['_ACKREQUEST']._serialized_end=385 - _globals['_ACKRESPONSE']._serialized_start=387 - _globals['_ACKRESPONSE']._serialized_end=400 - _globals['_NACKREQUEST']._serialized_start=402 - _globals['_NACKREQUEST']._serialized_end=465 - _globals['_NACKRESPONSE']._serialized_start=467 - _globals['_NACKRESPONSE']._serialized_end=481 - _globals['_FILASERVICE']._serialized_start=484 - _globals['_FILASERVICE']._serialized_end=726 + _globals['_CONSUMERESPONSE']._serialized_end=372 + _globals['_ACKREQUEST']._serialized_start=374 + _globals['_ACKREQUEST']._serialized_end=421 + _globals['_ACKRESPONSE']._serialized_start=423 + _globals['_ACKRESPONSE']._serialized_end=436 + _globals['_NACKREQUEST']._serialized_start=438 + _globals['_NACKREQUEST']._serialized_end=501 + _globals['_NACKRESPONSE']._serialized_start=503 + _globals['_NACKRESPONSE']._serialized_end=517 + _globals['_BATCHENQUEUEREQUEST']._serialized_start=519 + _globals['_BATCHENQUEUEREQUEST']._serialized_end=583 + _globals['_BATCHENQUEUERESPONSE']._serialized_start=585 + _globals['_BATCHENQUEUERESPONSE']._serialized_end=653 + _globals['_BATCHENQUEUERESULT']._serialized_start=655 + _globals['_BATCHENQUEUERESULT']._serialized_end=747 + _globals['_FILASERVICE']._serialized_start=750 + _globals['_FILASERVICE']._serialized_end=1069 # @@protoc_insertion_point(module_scope) diff --git a/fila/v1/service_pb2.pyi b/fila/v1/service_pb2.pyi index c6478c4..ca1e820 100644 --- a/fila/v1/service_pb2.pyi +++ b/fila/v1/service_pb2.pyi @@ -2,7 +2,7 @@ from fila.v1 import messages_pb2 as _messages_pb2 from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from collections.abc import Mapping as _Mapping +from collections.abc import Iterable as _Iterable, Mapping as _Mapping from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor @@ -37,10 +37,12 @@ class ConsumeRequest(_message.Message): def __init__(self, queue: _Optional[str] = ...) -> None: ... class ConsumeResponse(_message.Message): - __slots__ = ("message",) + __slots__ = ("message", "messages") MESSAGE_FIELD_NUMBER: _ClassVar[int] + MESSAGES_FIELD_NUMBER: _ClassVar[int] message: _messages_pb2.Message - def __init__(self, message: _Optional[_Union[_messages_pb2.Message, _Mapping]] = ...) -> None: ... + messages: _containers.RepeatedCompositeFieldContainer[_messages_pb2.Message] + def __init__(self, message: _Optional[_Union[_messages_pb2.Message, _Mapping]] = ..., messages: _Optional[_Iterable[_Union[_messages_pb2.Message, _Mapping]]] = ...) -> None: ... 
class AckRequest(_message.Message): __slots__ = ("queue", "message_id") @@ -67,3 +69,23 @@ class NackRequest(_message.Message): class NackResponse(_message.Message): __slots__ = () def __init__(self) -> None: ... + +class BatchEnqueueRequest(_message.Message): + __slots__ = ("messages",) + MESSAGES_FIELD_NUMBER: _ClassVar[int] + messages: _containers.RepeatedCompositeFieldContainer[EnqueueRequest] + def __init__(self, messages: _Optional[_Iterable[_Union[EnqueueRequest, _Mapping]]] = ...) -> None: ... + +class BatchEnqueueResponse(_message.Message): + __slots__ = ("results",) + RESULTS_FIELD_NUMBER: _ClassVar[int] + results: _containers.RepeatedCompositeFieldContainer[BatchEnqueueResult] + def __init__(self, results: _Optional[_Iterable[_Union[BatchEnqueueResult, _Mapping]]] = ...) -> None: ... + +class BatchEnqueueResult(_message.Message): + __slots__ = ("success", "error") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + success: EnqueueResponse + error: str + def __init__(self, success: _Optional[_Union[EnqueueResponse, _Mapping]] = ..., error: _Optional[str] = ...) -> None: ... diff --git a/fila/v1/service_pb2_grpc.py b/fila/v1/service_pb2_grpc.py index 663ae2a..0ef11e1 100644 --- a/fila/v1/service_pb2_grpc.py +++ b/fila/v1/service_pb2_grpc.py @@ -5,7 +5,7 @@ from fila.v1 import service_pb2 as fila_dot_v1_dot_service__pb2 -GRPC_GENERATED_VERSION = '1.78.0' +GRPC_GENERATED_VERSION = '1.78.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -40,6 +40,11 @@ def __init__(self, channel): request_serializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.SerializeToString, response_deserializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.FromString, _registered_method=True) + self.BatchEnqueue = channel.unary_unary( + '/fila.v1.FilaService/BatchEnqueue', + request_serializer=fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.SerializeToString, + response_deserializer=fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.FromString, + _registered_method=True) self.Consume = channel.unary_stream( '/fila.v1.FilaService/Consume', request_serializer=fila_dot_v1_dot_service__pb2.ConsumeRequest.SerializeToString, @@ -67,6 +72,12 @@ def Enqueue(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def BatchEnqueue(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def Consume(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -93,6 +104,11 @@ def add_FilaServiceServicer_to_server(servicer, server): request_deserializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.FromString, response_serializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.SerializeToString, ), + 'BatchEnqueue': grpc.unary_unary_rpc_method_handler( + servicer.BatchEnqueue, + request_deserializer=fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.FromString, + response_serializer=fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.SerializeToString, + ), 'Consume': grpc.unary_stream_rpc_method_handler( servicer.Consume, request_deserializer=fila_dot_v1_dot_service__pb2.ConsumeRequest.FromString, @@ -147,6 +163,33 @@ def Enqueue(request, metadata, _registered_method=True) + @staticmethod + def BatchEnqueue(request, + target, + options=(), + 
channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fila.v1.FilaService/BatchEnqueue', + fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.SerializeToString, + fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def Consume(request, target, diff --git a/proto/fila/v1/service.proto b/proto/fila/v1/service.proto index f14fdd0..fc0f710 100644 --- a/proto/fila/v1/service.proto +++ b/proto/fila/v1/service.proto @@ -6,6 +6,7 @@ import "fila/v1/messages.proto"; // Hot-path RPCs for producers and consumers. service FilaService { rpc Enqueue(EnqueueRequest) returns (EnqueueResponse); + rpc BatchEnqueue(BatchEnqueueRequest) returns (BatchEnqueueResponse); rpc Consume(ConsumeRequest) returns (stream ConsumeResponse); rpc Ack(AckRequest) returns (AckResponse); rpc Nack(NackRequest) returns (NackResponse); @@ -26,7 +27,8 @@ message ConsumeRequest { } message ConsumeResponse { - Message message = 1; + Message message = 1; // Single message (backward compatible, used when batch size is 1) + repeated Message messages = 2; // Batched messages (populated when server sends multiple at once) } message AckRequest { @@ -43,3 +45,18 @@ message NackRequest { } message NackResponse {} + +message BatchEnqueueRequest { + repeated EnqueueRequest messages = 1; +} + +message BatchEnqueueResponse { + repeated BatchEnqueueResult results = 1; +} + +message BatchEnqueueResult { + oneof result { + EnqueueResponse success = 1; + string error = 2; + } +} diff --git a/tests/test_batch_integration.py b/tests/test_batch_integration.py new file mode 100644 index 0000000..09aefb9 --- /dev/null +++ b/tests/test_batch_integration.py @@ -0,0 +1,220 @@ +"""Integration tests for batch enqueue and smart batching. + +These tests require a running fila-server binary. They are skipped +automatically when the server is not found (local dev). +""" + +from __future__ import annotations + +import pytest + +import fila + + +class TestBatchEnqueue: + """Integration tests for the explicit batch_enqueue method.""" + + def test_batch_enqueue_multiple_messages(self, server: object) -> None: + """batch_enqueue sends multiple messages in one RPC and returns per-message results.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-batch") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + results = client.batch_enqueue([ + ("test-batch", {"idx": "0"}, b"payload-0"), + ("test-batch", {"idx": "1"}, b"payload-1"), + ("test-batch", {"idx": "2"}, b"payload-2"), + ]) + + assert len(results) == 3 + for r in results: + assert r.is_success + assert r.message_id is not None + assert r.error is None + + # All message IDs should be unique. 
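+            # (the broker assigns UUIDv7 message IDs, so a duplicate here
+            # would point at a server bug rather than test flakiness)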
+ ids = [r.message_id for r in results] + assert len(set(ids)) == 3 + + def test_batch_enqueue_single_message(self, server: object) -> None: + """batch_enqueue works with a single message.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-batch-single") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + results = client.batch_enqueue([ + ("test-batch-single", None, b"solo"), + ]) + + assert len(results) == 1 + assert results[0].is_success + assert results[0].message_id is not None + + def test_batch_enqueue_consume_verify(self, server: object) -> None: + """Messages enqueued via batch_enqueue can be consumed and acked.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-batch-consume") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + results = client.batch_enqueue([ + ("test-batch-consume", {"k": "v"}, b"batch-msg"), + ]) + assert results[0].is_success + + stream = client.consume("test-batch-consume") + msg = next(stream) + + assert msg.id == results[0].message_id + assert msg.headers["k"] == "v" + assert msg.payload == b"batch-msg" + + client.ack("test-batch-consume", msg.id) + + +class TestAsyncBatchEnqueue: + """Integration tests for the async batch_enqueue method.""" + + @pytest.mark.asyncio + async def test_async_batch_enqueue(self, server: object) -> None: + """Async batch_enqueue sends multiple messages.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-async-batch") + + async with fila.AsyncClient(server.addr) as client: + results = await client.batch_enqueue([ + ("test-async-batch", None, b"async-0"), + ("test-async-batch", None, b"async-1"), + ]) + + assert len(results) == 2 + for r in results: + assert r.is_success + assert r.message_id is not None + + +class TestSmartBatching: + """Integration tests for smart batching (BatchMode.AUTO).""" + + def test_auto_mode_enqueue(self, server: object) -> None: + """AUTO mode enqueues messages through the batcher.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-auto-batch") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.AUTO) as client: + msg_id = client.enqueue("test-auto-batch", None, b"auto-msg") + assert msg_id != "" + + # Verify the message was actually enqueued. + stream = client.consume("test-auto-batch") + msg = next(stream) + assert msg.id == msg_id + assert msg.payload == b"auto-msg" + client.ack("test-auto-batch", msg.id) + + def test_auto_mode_multiple_messages(self, server: object) -> None: + """AUTO mode handles multiple sequential enqueues.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-auto-multi") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.AUTO) as client: + ids = [] + for i in range(5): + msg_id = client.enqueue( + "test-auto-multi", None, f"msg-{i}".encode() + ) + assert msg_id != "" + ids.append(msg_id) + + # All IDs should be unique. 
+ assert len(set(ids)) == 5 + + def test_disabled_mode_enqueue(self, server: object) -> None: + """DISABLED mode sends each enqueue as a direct RPC.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-disabled") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + msg_id = client.enqueue("test-disabled", None, b"direct") + assert msg_id != "" + + stream = client.consume("test-disabled") + msg = next(stream) + assert msg.id == msg_id + client.ack("test-disabled", msg.id) + + def test_linger_mode_enqueue(self, server: object) -> None: + """LINGER mode enqueues messages through a timer-based batcher.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-linger") + + with fila.Client( + server.addr, + batch_mode=fila.Linger(linger_ms=50, batch_size=10), + ) as client: + msg_id = client.enqueue("test-linger", None, b"lingered") + assert msg_id != "" + + stream = client.consume("test-linger") + msg = next(stream) + assert msg.id == msg_id + assert msg.payload == b"lingered" + client.ack("test-linger", msg.id) + + def test_default_mode_is_auto(self, server: object) -> None: + """Client defaults to AUTO batch mode.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-default-mode") + + # No batch_mode arg = AUTO. + with fila.Client(server.addr) as client: + msg_id = client.enqueue("test-default-mode", None, b"default") + assert msg_id != "" + + +class TestBatchModeTypes: + """Unit tests for BatchMode and Linger types (no server needed).""" + + def test_batch_mode_enum(self) -> None: + """BatchMode has AUTO and DISABLED variants.""" + assert fila.BatchMode.AUTO is not None + assert fila.BatchMode.DISABLED is not None + modes = {fila.BatchMode.AUTO, fila.BatchMode.DISABLED} + assert len(modes) == 2 # They are distinct values + + def test_linger_fields(self) -> None: + """Linger stores linger_ms and batch_size.""" + linger = fila.Linger(linger_ms=100, batch_size=50) + assert linger.linger_ms == 100 + assert linger.batch_size == 50 + + def test_batch_enqueue_result_success(self) -> None: + """BatchEnqueueResult.is_success returns True when message_id is set.""" + r = fila.BatchEnqueueResult(message_id="abc", error=None) + assert r.is_success + assert r.message_id == "abc" + assert r.error is None + + def test_batch_enqueue_result_error(self) -> None: + """BatchEnqueueResult.is_success returns False when error is set.""" + r = fila.BatchEnqueueResult(message_id=None, error="queue not found") + assert not r.is_success + assert r.message_id is None + assert r.error == "queue not found" diff --git a/tests/test_batcher.py b/tests/test_batcher.py new file mode 100644 index 0000000..ee10e71 --- /dev/null +++ b/tests/test_batcher.py @@ -0,0 +1,333 @@ +"""Unit tests for the batcher module. + +These tests use mock stubs and do not require a running fila-server. 
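+Server-backed coverage of the same code paths lives in
+tests/test_batch_integration.py.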
+""" + +from __future__ import annotations + +import threading +from concurrent.futures import Future +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from fila.batcher import ( + AutoBatcher, + LingerBatcher, + _EnqueueRequest, + _flush_batch, + _flush_single, +) +from fila.errors import BatchEnqueueError +from fila.v1 import service_pb2 + + +class FakeEnqueueResponse: + """Minimal fake for service_pb2.EnqueueResponse.""" + + def __init__(self, message_id: str) -> None: + self.message_id = message_id + + +class FakeBatchResult: + """Minimal fake for service_pb2.BatchEnqueueResult.""" + + def __init__(self, message_id: str | None = None, error: str | None = None) -> None: + self._message_id = message_id + self._error = error + self.success: FakeEnqueueResponse | None = ( + FakeEnqueueResponse(message_id) if message_id is not None else None + ) + self.error = error or "" + + def HasField(self, name: str) -> bool: # noqa: N802 + if name == "success": + return self._message_id is not None + return False + + +class FakeBatchResponse: + """Minimal fake for service_pb2.BatchEnqueueResponse.""" + + def __init__(self, results: list[FakeBatchResult]) -> None: + self.results = results + + +class TestFlushSingle: + """Test the _flush_single function.""" + + def test_success(self) -> None: + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("msg-001") + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"data") + fut: Future[str] = Future() + req = _EnqueueRequest(proto, fut) + + _flush_single(stub, req) + + assert fut.result(timeout=1.0) == "msg-001" + stub.Enqueue.assert_called_once_with(proto) + + def test_rpc_error(self) -> None: + import grpc + + stub = MagicMock() + rpc_error = MagicMock() + rpc_error.code.return_value = grpc.StatusCode.NOT_FOUND + rpc_error.details.return_value = "queue not found" + # Make it pass isinstance(e, grpc.RpcError) check. 
+ stub.Enqueue.side_effect = type( + "_FakeRpcError", (grpc.RpcError,), { + "code": lambda self: grpc.StatusCode.NOT_FOUND, + "details": lambda self: "queue not found", + } + )() + + proto = service_pb2.EnqueueRequest(queue="missing", payload=b"data") + fut: Future[str] = Future() + req = _EnqueueRequest(proto, fut) + + _flush_single(stub, req) + + from fila.errors import QueueNotFoundError + + with pytest.raises(QueueNotFoundError): + fut.result(timeout=1.0) + + +class TestFlushBatch: + """Test the _flush_batch function.""" + + def test_all_success(self) -> None: + stub = MagicMock() + stub.BatchEnqueue.return_value = FakeBatchResponse([ + FakeBatchResult(message_id="id-1"), + FakeBatchResult(message_id="id-2"), + ]) + + reqs = [ + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"a"), + Future(), + ), + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"b"), + Future(), + ), + ] + + _flush_batch(stub, reqs) + + assert reqs[0].future.result(timeout=1.0) == "id-1" + assert reqs[1].future.result(timeout=1.0) == "id-2" + + def test_mixed_results(self) -> None: + stub = MagicMock() + stub.BatchEnqueue.return_value = FakeBatchResponse([ + FakeBatchResult(message_id="id-1"), + FakeBatchResult(error="queue 'missing' not found"), + ]) + + reqs = [ + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"a"), + Future(), + ), + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="missing", payload=b"b"), + Future(), + ), + ] + + _flush_batch(stub, reqs) + + assert reqs[0].future.result(timeout=1.0) == "id-1" + with pytest.raises(BatchEnqueueError, match="queue 'missing' not found"): + reqs[1].future.result(timeout=1.0) + + def test_rpc_failure_sets_all_futures(self) -> None: + import grpc + + stub = MagicMock() + stub.BatchEnqueue.side_effect = type( + "_FakeRpcError", (grpc.RpcError,), { + "code": lambda self: grpc.StatusCode.UNAVAILABLE, + "details": lambda self: "server unavailable", + } + )() + + reqs = [ + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"a"), + Future(), + ), + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"b"), + Future(), + ), + ] + + _flush_batch(stub, reqs) + + for r in reqs: + with pytest.raises(BatchEnqueueError): + r.future.result(timeout=1.0) + + +class TestAutoBatcher: + """Test the AutoBatcher end-to-end.""" + + def test_single_message_uses_enqueue(self) -> None: + """When only one message is queued, AutoBatcher uses singular Enqueue.""" + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("msg-solo") + + batcher = AutoBatcher(stub, max_batch_size=100) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"solo") + fut = batcher.submit(proto) + result = fut.result(timeout=5.0) + + assert result == "msg-solo" + stub.Enqueue.assert_called_once() + stub.BatchEnqueue.assert_not_called() + + batcher.close() + + def test_concurrent_messages_batched(self) -> None: + """When multiple messages arrive concurrently, they batch together.""" + stub = MagicMock() + + # The first message will block Enqueue while more messages queue up. + # We need to make the batcher see all messages at once. + batch_called = threading.Event() + batch_response = FakeBatchResponse([ + FakeBatchResult(message_id=f"id-{i}") for i in range(5) + ]) + + def mock_batch_enqueue(request: Any) -> FakeBatchResponse: + batch_called.set() + return batch_response + + # Make single Enqueue block briefly so messages accumulate. 
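+        # The event is intentionally never set: if the batcher degrades to
+        # per-message Enqueue calls, each call returns a placeholder after
+        # its 5s wait instead of hanging the test forever.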
+ single_barrier = threading.Event() + + def mock_single_enqueue(request: Any) -> FakeEnqueueResponse: + single_barrier.wait(timeout=5.0) + return FakeEnqueueResponse("should-not-be-used") + + stub.Enqueue.side_effect = mock_single_enqueue + stub.BatchEnqueue.side_effect = mock_batch_enqueue + + batcher = AutoBatcher(stub, max_batch_size=100) + + # Submit 5 messages rapidly before the first can process. + # The batcher should drain them all in one batch. + protos = [ + service_pb2.EnqueueRequest(queue="q", payload=f"msg-{i}".encode()) + for i in range(5) + ] + + # We need to submit them in a way that they all arrive before + # the batcher loop drains. Use a barrier approach. + futures = [] + for p in protos: + futures.append(batcher.submit(p)) + + # Give the batcher thread time to drain and flush. + # Either BatchEnqueue or multiple Enqueue calls will resolve things. + for _i, f in enumerate(futures): + result = f.result(timeout=5.0) + assert result is not None + + batcher.close() + + def test_close_drains_pending(self) -> None: + """close() waits for pending messages to be flushed.""" + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("drained") + + batcher = AutoBatcher(stub, max_batch_size=100) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"drain-me") + fut = batcher.submit(proto) + + batcher.close() + + # After close, the future should be resolved. + assert fut.result(timeout=1.0) == "drained" + + def test_update_stub(self) -> None: + """update_stub replaces the gRPC stub used for flushing.""" + old_stub = MagicMock() + new_stub = MagicMock() + new_stub.Enqueue.return_value = FakeEnqueueResponse("new-stub") + + batcher = AutoBatcher(old_stub, max_batch_size=100) + + # Update stub before submitting. + batcher.update_stub(new_stub) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"data") + fut = batcher.submit(proto) + result = fut.result(timeout=5.0) + + assert result == "new-stub" + batcher.close() + + +class TestLingerBatcher: + """Test the LingerBatcher.""" + + def test_flushes_at_batch_size(self) -> None: + """Flush triggers when batch_size messages accumulate.""" + stub = MagicMock() + stub.BatchEnqueue.return_value = FakeBatchResponse([ + FakeBatchResult(message_id=f"id-{i}") for i in range(3) + ]) + + batcher = LingerBatcher(stub, linger_ms=5000, batch_size=3) + + futures = [] + for i in range(3): + proto = service_pb2.EnqueueRequest(queue="q", payload=f"m{i}".encode()) + futures.append(batcher.submit(proto)) + + # Should flush quickly because batch_size=3 was reached. + for i, f in enumerate(futures): + result = f.result(timeout=5.0) + assert result == f"id-{i}" + + batcher.close() + + def test_flushes_at_linger_timeout(self) -> None: + """Flush triggers after linger_ms even if batch_size is not reached.""" + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("lingered") + + batcher = LingerBatcher(stub, linger_ms=50, batch_size=100) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"linger") + fut = batcher.submit(proto) + + # Should flush after ~50ms even though batch_size=100 not reached. 
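+        # The generous result timeout leaves headroom over linger_ms so
+        # slow CI hosts don't flake.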
+ result = fut.result(timeout=5.0) + assert result == "lingered" + + batcher.close() + + def test_close_drains_pending(self) -> None: + """close() drains any pending messages.""" + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("drained") + + batcher = LingerBatcher(stub, linger_ms=10000, batch_size=100) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"drain") + fut = batcher.submit(proto) + + batcher.close() + + assert fut.result(timeout=1.0) == "drained"