Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions fila/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ async def intercept_unary_stream(
return await continuation(new_details, request)


# Trailing-metadata key whose value is read as the leader address hint.
_LEADER_HINT_KEY = "x-fila-leader-addr"


def _extract_leader_hint(err: grpc.RpcError) -> str | None:
    """Return the leader address from trailing metadata, if present.

    Only errors whose status code is ``UNAVAILABLE`` are inspected. The first
    non-blank value stored under ``_LEADER_HINT_KEY`` is returned.

    Args:
        err: The RPC error raised by a failed call.

    Returns:
        The leader address as a string, or ``None`` when the status is not
        ``UNAVAILABLE``, there is no trailing metadata, or no usable hint
        is found.
    """
    if err.code() != grpc.StatusCode.UNAVAILABLE:
        return None
    trailing = err.trailing_metadata()
    if trailing is None:
        return None
    for key, value in trailing:
        if key != _LEADER_HINT_KEY:
            continue
        # str() on bytes would yield the "b'...'" repr instead of the address.
        if isinstance(value, bytes):
            addr = value.decode("utf-8", "replace")
        else:
            addr = str(value)
        addr = addr.strip()
        # A blank hint is unusable as a dial target; treat it as absent so the
        # caller surfaces the original error instead of reconnecting to "".
        if addr:
            return addr
    return None


class AsyncClient:
"""Asynchronous client for the Fila message broker.

Expand Down Expand Up @@ -162,32 +178,41 @@ def __init__(
api_key: API key for authentication. When set, every RPC includes an
``authorization: Bearer <key>`` metadata header.
"""
use_tls = tls or ca_cert is not None
self._tls = tls
self._ca_cert = ca_cert
self._client_cert = client_cert
self._client_key = client_key
self._api_key = api_key

use_tls = tls or ca_cert is not None
if (client_cert is not None or client_key is not None) and not use_tls:
raise ValueError(
"client_cert and client_key require ca_cert or tls=True to establish a TLS channel"
)

self._channel = self._make_channel(addr)
self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call]

def _make_channel(self, addr: str) -> grpc.aio.Channel:
    """Create an async gRPC channel to the given address using stored credentials.

    TLS is used when ``self._tls`` is truthy or a CA certificate was stored.
    When an API key was stored, an ``_AsyncApiKeyInterceptor`` is attached to
    the channel.

    Args:
        addr: Target address to connect to.

    Returns:
        A new ``grpc.aio`` channel; the caller owns it and must close it.
    """
    use_tls = self._tls or self._ca_cert is not None

    interceptors: list[grpc.aio.ClientInterceptor] = []
    if self._api_key is not None:
        interceptors.append(_AsyncApiKeyInterceptor(self._api_key))

    if use_tls:
        creds = grpc.ssl_channel_credentials(
            root_certificates=self._ca_cert,
            private_key=self._client_key,
            certificate_chain=self._client_cert,
        )
        # An empty interceptor list is passed as None.
        return grpc.aio.secure_channel(
            addr, creds, interceptors=interceptors or None
        )
    else:
        return grpc.aio.insecure_channel(
            addr, interceptors=interceptors or None
        )

async def close(self) -> None:
"""Close the underlying gRPC channel."""
Expand Down Expand Up @@ -238,6 +263,10 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]:
server stream closes or an error occurs. Nil message frames (keepalive
signals) are skipped automatically.

If the server returns UNAVAILABLE with an ``x-fila-leader-addr``
trailing metadata entry, the client transparently reconnects to the
leader address and retries the consume call once.

Args:
queue: Queue to consume from.

Expand All @@ -253,10 +282,26 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]:
service_pb2.ConsumeRequest(queue=queue)
)
except grpc.RpcError as e:
raise _map_consume_error(e) from e
leader_addr = _extract_leader_hint(e)
if leader_addr is not None:
stream = await self._reconnect_and_consume(leader_addr, queue)
else:
raise _map_consume_error(e) from e

return self._consume_iter(stream)

async def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any:
    """Create a new channel to *leader_addr* and retry the consume call.

    The replacement channel is built before the current one is closed, so if
    channel construction raises, the client keeps its existing open channel
    instead of being left with a closed one.

    Args:
        leader_addr: Address to reconnect to.
        queue: Queue to consume from.

    Returns:
        Whatever ``Consume`` on the new stub returns.

    Raises:
        The exception produced by ``_map_consume_error`` when the retried
        call raises ``grpc.RpcError``.
    """
    new_channel = self._make_channel(leader_addr)
    old_channel = self._channel
    # Swap channel and stub together so they never refer to different targets.
    self._channel = new_channel
    self._stub = service_pb2_grpc.FilaServiceStub(new_channel)  # type: ignore[no-untyped-call]
    await old_channel.close()
    try:
        return self._stub.Consume(
            service_pb2.ConsumeRequest(queue=queue)
        )
    except grpc.RpcError as e:
        raise _map_consume_error(e) from e

async def _consume_iter(
self,
stream: Any,
Expand Down
73 changes: 62 additions & 11 deletions fila/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,25 @@
if TYPE_CHECKING:
from collections.abc import Iterator

# Trailing-metadata key whose value is read as the leader address hint.
_LEADER_HINT_KEY = "x-fila-leader-addr"


def _extract_leader_hint(err: grpc.RpcError) -> str | None:
    """Return the leader address from trailing metadata, if present.

    The server sets ``x-fila-leader-addr`` in trailing metadata alongside an
    UNAVAILABLE status when the node is not the leader for the requested queue.
    The first non-blank value stored under that key is returned.

    Args:
        err: The RPC error raised by a failed call.

    Returns:
        The leader address as a string, or ``None`` when the status is not
        ``UNAVAILABLE``, there is no trailing metadata, or no usable hint
        is found.
    """
    if err.code() != grpc.StatusCode.UNAVAILABLE:
        return None
    trailing = err.trailing_metadata()
    if trailing is None:
        return None
    for key, value in trailing:
        if key != _LEADER_HINT_KEY:
            continue
        # str() on bytes would yield the "b'...'" repr instead of the address.
        if isinstance(value, bytes):
            addr = value.decode("utf-8", "replace")
        else:
            addr = str(value)
        addr = addr.strip()
        # A blank hint is unusable as a dial target; treat it as absent so the
        # caller surfaces the original error instead of reconnecting to "".
        if addr:
            return addr
    return None


class _ClientCallDetails(
grpc.ClientCallDetails, # type: ignore[misc]
Expand Down Expand Up @@ -143,28 +162,40 @@ def __init__(
api_key: API key for authentication. When set, every RPC includes an
``authorization: Bearer <key>`` metadata header.
"""
use_tls = tls or ca_cert is not None
self._tls = tls
self._ca_cert = ca_cert
self._client_cert = client_cert
self._client_key = client_key
self._api_key = api_key

use_tls = tls or ca_cert is not None
if (client_cert is not None or client_key is not None) and not use_tls:
raise ValueError(
"client_cert and client_key require ca_cert or tls=True to establish a TLS channel"
)

self._channel = self._make_channel(addr)
self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call]

def _make_channel(self, addr: str) -> grpc.Channel:
    """Create a gRPC channel to the given address using stored credentials.

    TLS is used when ``self._tls`` is truthy or a CA certificate was stored.
    When an API key was stored, the channel is wrapped with an
    ``_ApiKeyInterceptor``.

    Args:
        addr: Target address to connect to.

    Returns:
        A new channel (possibly interceptor-wrapped); the caller owns it and
        must close it.
    """
    use_tls = self._tls or self._ca_cert is not None

    if use_tls:
        creds = grpc.ssl_channel_credentials(
            root_certificates=self._ca_cert,
            private_key=self._client_key,
            certificate_chain=self._client_cert,
        )
        channel: grpc.Channel = grpc.secure_channel(addr, creds)
    else:
        channel = grpc.insecure_channel(addr)

    if self._api_key is not None:
        interceptor = _ApiKeyInterceptor(self._api_key)
        channel = grpc.intercept_channel(channel, interceptor)

    return channel

def close(self) -> None:
"""Close the underlying gRPC channel."""
Expand Down Expand Up @@ -215,6 +246,10 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]:
server stream closes or an error occurs. Skip nil message frames
(keepalive signals) automatically.

If the server returns UNAVAILABLE with an ``x-fila-leader-addr``
trailing metadata entry, the client transparently reconnects to the
leader address and retries the consume call once.

Args:
queue: Queue to consume from.

Expand All @@ -230,10 +265,26 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]:
service_pb2.ConsumeRequest(queue=queue)
)
except grpc.RpcError as e:
raise _map_consume_error(e) from e
leader_addr = _extract_leader_hint(e)
if leader_addr is not None:
stream = self._reconnect_and_consume(leader_addr, queue)
else:
raise _map_consume_error(e) from e

return self._consume_iter(stream)

def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any:
    """Create a new channel to *leader_addr* and retry the consume call.

    The replacement channel is built before the current one is closed, so if
    channel construction raises, the client keeps its existing open channel
    instead of being left with a closed one.

    Args:
        leader_addr: Address to reconnect to.
        queue: Queue to consume from.

    Returns:
        Whatever ``Consume`` on the new stub returns.

    Raises:
        The exception produced by ``_map_consume_error`` when the retried
        call raises ``grpc.RpcError``.
    """
    new_channel = self._make_channel(leader_addr)
    old_channel = self._channel
    # Swap channel and stub together so they never refer to different targets.
    self._channel = new_channel
    self._stub = service_pb2_grpc.FilaServiceStub(new_channel)  # type: ignore[no-untyped-call]
    old_channel.close()
    try:
        return self._stub.Consume(
            service_pb2.ConsumeRequest(queue=queue)
        )
    except grpc.RpcError as e:
        raise _map_consume_error(e) from e

def _consume_iter(
self,
stream: Any,
Expand Down
Loading