Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,67 @@ async def main():
asyncio.run(main())
```

## TLS

For servers using certificates from a public CA (e.g., Let's Encrypt), enable TLS
with the system trust store:

```python
from fila import Client

# TLS using OS system trust store.
client = Client("localhost:5555", tls=True)
```

For servers using a private CA, provide the CA certificate explicitly:

```python
from fila import Client

# Read certificates.
with open("ca.pem", "rb") as f:
ca_cert = f.read()
with open("client.pem", "rb") as f:
client_cert = f.read()
with open("client-key.pem", "rb") as f:
client_key = f.read()

# TLS with custom CA (server verification).
client = Client("localhost:5555", ca_cert=ca_cert)

# Mutual TLS (client + server verification).
client = Client(
"localhost:5555",
ca_cert=ca_cert,
client_cert=client_cert,
client_key=client_key,
)
```

Note: `ca_cert` implies `tls=True` -- you don't need to pass both.

## API Key Authentication

When the server has API key auth enabled, pass the key to the client:

```python
from fila import Client

client = Client("localhost:5555", api_key="fila_your_api_key_here")

# Combined with TLS:
client = Client(
"localhost:5555",
ca_cert=ca_cert,
api_key="fila_your_api_key_here",
)
```

The API key is sent as `authorization: Bearer <key>` metadata on every RPC.

## API

### `Client(addr)` / `AsyncClient(addr)`
### `Client(addr, *, tls=False, ca_cert=None, client_cert=None, client_key=None, api_key=None)` / `AsyncClient(...)`

Connect to a Fila broker. Both support context manager protocol.

Expand Down
150 changes: 148 additions & 2 deletions fila/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,90 @@
from fila.v1 import service_pb2, service_pb2_grpc
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.


class _AsyncClientCallDetails(
grpc.aio.ClientCallDetails, # type: ignore[misc]
):
"""Concrete ``ClientCallDetails`` for the async interceptor chain.

``grpc.aio.ClientCallDetails`` is a namedtuple with 5 fields (method,
timeout, metadata, credentials, wait_for_ready). We override ``__new__``
so the namedtuple layer receives exactly those five, then set any extra
attribute (``compression``) in ``__init__``.
"""

def __new__(
cls,
method: str,
timeout: float | None,
metadata: grpc.aio.Metadata | None,
credentials: grpc.CallCredentials | None,
wait_for_ready: bool | None,
) -> _AsyncClientCallDetails:
return super().__new__(cls, method, timeout, metadata, credentials, wait_for_ready) # type: ignore[no-any-return]

def __init__(
self,
method: str,
timeout: float | None,
metadata: grpc.aio.Metadata | None,
credentials: grpc.CallCredentials | None,
wait_for_ready: bool | None,
) -> None:
# Fields are already set by __new__ (namedtuple). Nothing extra to do.
pass


class _AsyncApiKeyInterceptor(
grpc.aio.UnaryUnaryClientInterceptor, # type: ignore[misc]
grpc.aio.UnaryStreamClientInterceptor, # type: ignore[misc]
):
"""Injects ``authorization: Bearer <key>`` metadata into every async RPC."""

def __init__(self, api_key: str) -> None:
self._metadata = grpc.aio.Metadata(("authorization", f"Bearer {api_key}"))

def _inject(
self, metadata: grpc.aio.Metadata | None
) -> grpc.aio.Metadata:
merged = grpc.aio.Metadata()
if metadata is not None:
for key, value in metadata:
merged.add(key, value)
for key, value in self._metadata:
merged.add(key, value)
return merged

async def intercept_unary_unary(
self,
continuation: Any,
client_call_details: grpc.aio.ClientCallDetails,
request: Any,
) -> Any:
new_details = _AsyncClientCallDetails(
client_call_details.method,
client_call_details.timeout,
self._inject(client_call_details.metadata),
client_call_details.credentials,
client_call_details.wait_for_ready,
)
return await continuation(new_details, request)

async def intercept_unary_stream(
self,
continuation: Any,
client_call_details: grpc.aio.ClientCallDetails,
request: Any,
) -> Any:
new_details = _AsyncClientCallDetails(
client_call_details.method,
client_call_details.timeout,
self._inject(client_call_details.metadata),
client_call_details.credentials,
client_call_details.wait_for_ready,
)
return await continuation(new_details, request)


class AsyncClient:
"""Asynchronous client for the Fila message broker.

Expand All @@ -32,15 +116,77 @@ class AsyncClient:

async with AsyncClient("localhost:5555") as client:
await client.enqueue("my-queue", None, b"hello")

TLS (system trust store)::

client = AsyncClient("localhost:5555", tls=True)

TLS (custom CA)::

with open("ca.pem", "rb") as f:
ca = f.read()
client = AsyncClient("localhost:5555", ca_cert=ca)

mTLS + API key::

client = AsyncClient(
"localhost:5555",
ca_cert=ca,
client_cert=cert,
client_key=key,
api_key="fila_...",
)
"""

def __init__(self, addr: str) -> None:
def __init__(
self,
addr: str,
*,
tls: bool = False,
ca_cert: bytes | None = None,
client_cert: bytes | None = None,
client_key: bytes | None = None,
api_key: str | None = None,
) -> None:
"""Connect to a Fila broker at the given address.

Args:
addr: Broker address in "host:port" format (e.g., "localhost:5555").
tls: Enable TLS using the OS system trust store for server
verification. Ignored when ``ca_cert`` is provided (which
implies TLS). Defaults to ``False``.
ca_cert: PEM-encoded CA certificate for verifying the server.
When provided, a TLS channel is used instead of an insecure one.
client_cert: PEM-encoded client certificate for mutual TLS (optional).
client_key: PEM-encoded client private key for mutual TLS (optional).
api_key: API key for authentication. When set, every RPC includes an
``authorization: Bearer <key>`` metadata header.
"""
self._channel = grpc.aio.insecure_channel(addr)
use_tls = tls or ca_cert is not None

if (client_cert is not None or client_key is not None) and not use_tls:
raise ValueError(
"client_cert and client_key require ca_cert or tls=True to establish a TLS channel"
)

interceptors: list[grpc.aio.ClientInterceptor] = []
if api_key is not None:
interceptors.append(_AsyncApiKeyInterceptor(api_key))

if use_tls:
creds = grpc.ssl_channel_credentials(
root_certificates=ca_cert,
private_key=client_key,
certificate_chain=client_cert,
)
self._channel = grpc.aio.secure_channel(
addr, creds, interceptors=interceptors or None
)
else:
self._channel = grpc.aio.insecure_channel(
addr, interceptors=interceptors or None
)

self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call]

async def close(self) -> None:
Expand Down
128 changes: 126 additions & 2 deletions fila/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,72 @@
from collections.abc import Iterator


class _ClientCallDetails(
grpc.ClientCallDetails, # type: ignore[misc]
):
"""Concrete ``ClientCallDetails`` that can be instantiated.

``grpc.ClientCallDetails`` is an abstract class with no ``__init__``, so we
need our own subclass to carry the fields through the interceptor chain.
"""

def __init__(
self,
method: str,
timeout: float | None,
metadata: list[tuple[str, str | bytes]] | None,
credentials: grpc.CallCredentials | None,
wait_for_ready: bool | None,
compression: grpc.Compression | None,
) -> None:
self.method = method
self.timeout = timeout
self.metadata = metadata
self.credentials = credentials
self.wait_for_ready = wait_for_ready
self.compression = compression


class _ApiKeyInterceptor(
grpc.UnaryUnaryClientInterceptor, # type: ignore[misc]
grpc.UnaryStreamClientInterceptor, # type: ignore[misc]
):
"""Injects ``authorization: Bearer <key>`` metadata into every RPC."""

def __init__(self, api_key: str) -> None:
self._metadata = (("authorization", f"Bearer {api_key}"),)

def _inject(
self, client_call_details: grpc.ClientCallDetails
) -> _ClientCallDetails:
metadata = list(client_call_details.metadata or [])
metadata.extend(self._metadata)
return _ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
client_call_details.wait_for_ready,
client_call_details.compression,
)

def intercept_unary_unary(
self,
continuation: Any,
client_call_details: grpc.ClientCallDetails,
request: Any,
) -> Any:
return continuation(self._inject(client_call_details), request)

def intercept_unary_stream(
self,
continuation: Any,
client_call_details: grpc.ClientCallDetails,
request: Any,
) -> Any:
return continuation(self._inject(client_call_details), request)


class Client:
"""Synchronous client for the Fila message broker.

Expand All @@ -31,15 +97,73 @@ class Client:

with Client("localhost:5555") as client:
client.enqueue("my-queue", None, b"hello")

TLS (system trust store)::

client = Client("localhost:5555", tls=True)

TLS (custom CA)::

with open("ca.pem", "rb") as f:
ca = f.read()
client = Client("localhost:5555", ca_cert=ca)

mTLS + API key::

client = Client(
"localhost:5555",
ca_cert=ca,
client_cert=cert,
client_key=key,
api_key="fila_...",
)
"""

def __init__(self, addr: str) -> None:
def __init__(
self,
addr: str,
*,
tls: bool = False,
ca_cert: bytes | None = None,
client_cert: bytes | None = None,
client_key: bytes | None = None,
api_key: str | None = None,
) -> None:
"""Connect to a Fila broker at the given address.

Args:
addr: Broker address in "host:port" format (e.g., "localhost:5555").
tls: Enable TLS using the OS system trust store for server
verification. Ignored when ``ca_cert`` is provided (which
implies TLS). Defaults to ``False``.
ca_cert: PEM-encoded CA certificate for verifying the server.
When provided, a TLS channel is used instead of an insecure one.
client_cert: PEM-encoded client certificate for mutual TLS (optional).
client_key: PEM-encoded client private key for mutual TLS (optional).
api_key: API key for authentication. When set, every RPC includes an
``authorization: Bearer <key>`` metadata header.
"""
self._channel = grpc.insecure_channel(addr)
use_tls = tls or ca_cert is not None

if (client_cert is not None or client_key is not None) and not use_tls:
raise ValueError(
"client_cert and client_key require ca_cert or tls=True to establish a TLS channel"
)

if use_tls:
creds = grpc.ssl_channel_credentials(
root_certificates=ca_cert,
private_key=client_key,
certificate_chain=client_cert,
)
self._channel = grpc.secure_channel(addr, creds)
else:
self._channel = grpc.insecure_channel(addr)

if api_key is not None:
interceptor = _ApiKeyInterceptor(api_key)
self._channel = grpc.intercept_channel(self._channel, interceptor)

self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call]

def close(self) -> None:
Expand Down
Loading
Loading