diff --git a/fila/async_client.py b/fila/async_client.py index 9c99b50..8bab27b 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -99,6 +99,22 @@ async def intercept_unary_stream( return await continuation(new_details, request) +_LEADER_HINT_KEY = "x-fila-leader-addr" + + +def _extract_leader_hint(err: grpc.RpcError) -> str | None: + """Return the leader address from trailing metadata, if present.""" + if err.code() != grpc.StatusCode.UNAVAILABLE: + return None + trailing = err.trailing_metadata() + if trailing is None: + return None + for key, value in trailing: + if key == _LEADER_HINT_KEY: + return str(value) + return None + + class AsyncClient: """Asynchronous client for the Fila message broker. @@ -162,32 +178,41 @@ def __init__( api_key: API key for authentication. When set, every RPC includes an ``authorization: Bearer `` metadata header. """ - use_tls = tls or ca_cert is not None + self._tls = tls + self._ca_cert = ca_cert + self._client_cert = client_cert + self._client_key = client_key + self._api_key = api_key + use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: raise ValueError( "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" ) + self._channel = self._make_channel(addr) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + + def _make_channel(self, addr: str) -> grpc.aio.Channel: + """Create an async gRPC channel to the given address using stored credentials.""" + use_tls = self._tls or self._ca_cert is not None + interceptors: list[grpc.aio.ClientInterceptor] = [] - if api_key is not None: - interceptors.append(_AsyncApiKeyInterceptor(api_key)) + if self._api_key is not None: + interceptors.append(_AsyncApiKeyInterceptor(self._api_key)) if use_tls: creds = grpc.ssl_channel_credentials( - root_certificates=ca_cert, - private_key=client_key, - certificate_chain=client_cert, + root_certificates=self._ca_cert, + private_key=self._client_key, + certificate_chain=self._client_cert, ) - self._channel = grpc.aio.secure_channel( + return grpc.aio.secure_channel( addr, creds, interceptors=interceptors or None ) - else: - self._channel = grpc.aio.insecure_channel( - addr, interceptors=interceptors or None - ) - - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + return grpc.aio.insecure_channel( + addr, interceptors=interceptors or None + ) async def close(self) -> None: """Close the underlying gRPC channel.""" @@ -238,6 +263,10 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: server stream closes or an error occurs. Nil message frames (keepalive signals) are skipped automatically. + If the server returns UNAVAILABLE with an ``x-fila-leader-addr`` + trailing metadata entry, the client transparently reconnects to the + leader address and retries the consume call once. + Args: queue: Queue to consume from. @@ -253,10 +282,26 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: service_pb2.ConsumeRequest(queue=queue) ) except grpc.RpcError as e: - raise _map_consume_error(e) from e + leader_addr = _extract_leader_hint(e) + if leader_addr is not None: + stream = await self._reconnect_and_consume(leader_addr, queue) + else: + raise _map_consume_error(e) from e return self._consume_iter(stream) + async def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: + """Create a new channel to *leader_addr* and retry the consume call.""" + await self._channel.close() + self._channel = self._make_channel(leader_addr) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + try: + return self._stub.Consume( + service_pb2.ConsumeRequest(queue=queue) + ) + except grpc.RpcError as e: + raise _map_consume_error(e) from e + async def _consume_iter( self, stream: Any, diff --git a/fila/client.py b/fila/client.py index 531c051..891907a 100644 --- a/fila/client.py +++ b/fila/client.py @@ -13,6 +13,25 @@ if TYPE_CHECKING: from collections.abc import Iterator +_LEADER_HINT_KEY = "x-fila-leader-addr" + + +def _extract_leader_hint(err: grpc.RpcError) -> str | None: + """Return the leader address from trailing metadata, if present. + + The server sets ``x-fila-leader-addr`` in trailing metadata alongside an + UNAVAILABLE status when the node is not the leader for the requested queue. + """ + if err.code() != grpc.StatusCode.UNAVAILABLE: + return None + trailing = err.trailing_metadata() + if trailing is None: + return None + for key, value in trailing: + if key == _LEADER_HINT_KEY: + return str(value) + return None + class _ClientCallDetails( grpc.ClientCallDetails, # type: ignore[misc] @@ -143,28 +162,40 @@ def __init__( api_key: API key for authentication. When set, every RPC includes an ``authorization: Bearer `` metadata header. """ - use_tls = tls or ca_cert is not None + self._tls = tls + self._ca_cert = ca_cert + self._client_cert = client_cert + self._client_key = client_key + self._api_key = api_key + use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: raise ValueError( "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" ) + self._channel = self._make_channel(addr) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + + def _make_channel(self, addr: str) -> grpc.Channel: + """Create a gRPC channel to the given address using stored credentials.""" + use_tls = self._tls or self._ca_cert is not None + if use_tls: creds = grpc.ssl_channel_credentials( - root_certificates=ca_cert, - private_key=client_key, - certificate_chain=client_cert, + root_certificates=self._ca_cert, + private_key=self._client_key, + certificate_chain=self._client_cert, ) - self._channel = grpc.secure_channel(addr, creds) + channel: grpc.Channel = grpc.secure_channel(addr, creds) else: - self._channel = grpc.insecure_channel(addr) + channel = grpc.insecure_channel(addr) - if api_key is not None: - interceptor = _ApiKeyInterceptor(api_key) - self._channel = grpc.intercept_channel(self._channel, interceptor) + if self._api_key is not None: + interceptor = _ApiKeyInterceptor(self._api_key) + channel = grpc.intercept_channel(channel, interceptor) - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + return channel def close(self) -> None: """Close the underlying gRPC channel.""" @@ -215,6 +246,10 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]: server stream closes or an error occurs. Skip nil message frames (keepalive signals) automatically. + If the server returns UNAVAILABLE with an ``x-fila-leader-addr`` + trailing metadata entry, the client transparently reconnects to the + leader address and retries the consume call once. + Args: queue: Queue to consume from. @@ -230,10 +265,26 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]: service_pb2.ConsumeRequest(queue=queue) ) except grpc.RpcError as e: - raise _map_consume_error(e) from e + leader_addr = _extract_leader_hint(e) + if leader_addr is not None: + stream = self._reconnect_and_consume(leader_addr, queue) + else: + raise _map_consume_error(e) from e return self._consume_iter(stream) + def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: + """Create a new channel to *leader_addr* and retry the consume call.""" + self._channel.close() + self._channel = self._make_channel(leader_addr) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + try: + return self._stub.Consume( + service_pb2.ConsumeRequest(queue=queue) + ) + except grpc.RpcError as e: + raise _map_consume_error(e) from e + def _consume_iter( self, stream: Any,