diff --git a/README.md b/README.md index 8a4d189..ad70043 100644 --- a/README.md +++ b/README.md @@ -46,9 +46,67 @@ async def main(): asyncio.run(main()) ``` +## TLS + +For servers using certificates from a public CA (e.g., Let's Encrypt), enable TLS +with the system trust store: + +```python +from fila import Client + +# TLS using OS system trust store. +client = Client("localhost:5555", tls=True) +``` + +For servers using a private CA, provide the CA certificate explicitly: + +```python +from fila import Client + +# Read certificates. +with open("ca.pem", "rb") as f: + ca_cert = f.read() +with open("client.pem", "rb") as f: + client_cert = f.read() +with open("client-key.pem", "rb") as f: + client_key = f.read() + +# TLS with custom CA (server verification). +client = Client("localhost:5555", ca_cert=ca_cert) + +# Mutual TLS (client + server verification). +client = Client( + "localhost:5555", + ca_cert=ca_cert, + client_cert=client_cert, + client_key=client_key, +) +``` + +Note: `ca_cert` implies `tls=True` -- you don't need to pass both. + +## API Key Authentication + +When the server has API key auth enabled, pass the key to the client: + +```python +from fila import Client + +client = Client("localhost:5555", api_key="fila_your_api_key_here") + +# Combined with TLS: +client = Client( + "localhost:5555", + ca_cert=ca_cert, + api_key="fila_your_api_key_here", +) +``` + +The API key is sent as `authorization: Bearer ` metadata on every RPC. + ## API -### `Client(addr)` / `AsyncClient(addr)` +### `Client(addr, *, tls=False, ca_cert=None, client_cert=None, client_key=None, api_key=None)` / `AsyncClient(...)` Connect to a Fila broker. Both support context manager protocol. diff --git a/fila/async_client.py b/fila/async_client.py index 2cbcff2..9c99b50 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -15,6 +15,90 @@ from fila.v1 import service_pb2, service_pb2_grpc +class _AsyncClientCallDetails( + grpc.aio.ClientCallDetails, # type: ignore[misc] +): + """Concrete ``ClientCallDetails`` for the async interceptor chain. + + ``grpc.aio.ClientCallDetails`` is a namedtuple with 5 fields (method, + timeout, metadata, credentials, wait_for_ready). We override ``__new__`` + so the namedtuple layer receives exactly those five, then set any extra + attribute (``compression``) in ``__init__``. + """ + + def __new__( + cls, + method: str, + timeout: float | None, + metadata: grpc.aio.Metadata | None, + credentials: grpc.CallCredentials | None, + wait_for_ready: bool | None, + ) -> _AsyncClientCallDetails: + return super().__new__(cls, method, timeout, metadata, credentials, wait_for_ready) # type: ignore[no-any-return] + + def __init__( + self, + method: str, + timeout: float | None, + metadata: grpc.aio.Metadata | None, + credentials: grpc.CallCredentials | None, + wait_for_ready: bool | None, + ) -> None: + # Fields are already set by __new__ (namedtuple). Nothing extra to do. + pass + + +class _AsyncApiKeyInterceptor( + grpc.aio.UnaryUnaryClientInterceptor, # type: ignore[misc] + grpc.aio.UnaryStreamClientInterceptor, # type: ignore[misc] +): + """Injects ``authorization: Bearer `` metadata into every async RPC.""" + + def __init__(self, api_key: str) -> None: + self._metadata = grpc.aio.Metadata(("authorization", f"Bearer {api_key}")) + + def _inject( + self, metadata: grpc.aio.Metadata | None + ) -> grpc.aio.Metadata: + merged = grpc.aio.Metadata() + if metadata is not None: + for key, value in metadata: + merged.add(key, value) + for key, value in self._metadata: + merged.add(key, value) + return merged + + async def intercept_unary_unary( + self, + continuation: Any, + client_call_details: grpc.aio.ClientCallDetails, + request: Any, + ) -> Any: + new_details = _AsyncClientCallDetails( + client_call_details.method, + client_call_details.timeout, + self._inject(client_call_details.metadata), + client_call_details.credentials, + client_call_details.wait_for_ready, + ) + return await continuation(new_details, request) + + async def intercept_unary_stream( + self, + continuation: Any, + client_call_details: grpc.aio.ClientCallDetails, + request: Any, + ) -> Any: + new_details = _AsyncClientCallDetails( + client_call_details.method, + client_call_details.timeout, + self._inject(client_call_details.metadata), + client_call_details.credentials, + client_call_details.wait_for_ready, + ) + return await continuation(new_details, request) + + class AsyncClient: """Asynchronous client for the Fila message broker. @@ -32,15 +116,77 @@ class AsyncClient: async with AsyncClient("localhost:5555") as client: await client.enqueue("my-queue", None, b"hello") + + TLS (system trust store):: + + client = AsyncClient("localhost:5555", tls=True) + + TLS (custom CA):: + + with open("ca.pem", "rb") as f: + ca = f.read() + client = AsyncClient("localhost:5555", ca_cert=ca) + + mTLS + API key:: + + client = AsyncClient( + "localhost:5555", + ca_cert=ca, + client_cert=cert, + client_key=key, + api_key="fila_...", + ) """ - def __init__(self, addr: str) -> None: + def __init__( + self, + addr: str, + *, + tls: bool = False, + ca_cert: bytes | None = None, + client_cert: bytes | None = None, + client_key: bytes | None = None, + api_key: str | None = None, + ) -> None: """Connect to a Fila broker at the given address. Args: addr: Broker address in "host:port" format (e.g., "localhost:5555"). + tls: Enable TLS using the OS system trust store for server + verification. Ignored when ``ca_cert`` is provided (which + implies TLS). Defaults to ``False``. + ca_cert: PEM-encoded CA certificate for verifying the server. + When provided, a TLS channel is used instead of an insecure one. + client_cert: PEM-encoded client certificate for mutual TLS (optional). + client_key: PEM-encoded client private key for mutual TLS (optional). + api_key: API key for authentication. When set, every RPC includes an + ``authorization: Bearer `` metadata header. """ - self._channel = grpc.aio.insecure_channel(addr) + use_tls = tls or ca_cert is not None + + if (client_cert is not None or client_key is not None) and not use_tls: + raise ValueError( + "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" + ) + + interceptors: list[grpc.aio.ClientInterceptor] = [] + if api_key is not None: + interceptors.append(_AsyncApiKeyInterceptor(api_key)) + + if use_tls: + creds = grpc.ssl_channel_credentials( + root_certificates=ca_cert, + private_key=client_key, + certificate_chain=client_cert, + ) + self._channel = grpc.aio.secure_channel( + addr, creds, interceptors=interceptors or None + ) + else: + self._channel = grpc.aio.insecure_channel( + addr, interceptors=interceptors or None + ) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] async def close(self) -> None: diff --git a/fila/client.py b/fila/client.py index 3123c30..531c051 100644 --- a/fila/client.py +++ b/fila/client.py @@ -14,6 +14,72 @@ from collections.abc import Iterator +class _ClientCallDetails( + grpc.ClientCallDetails, # type: ignore[misc] +): + """Concrete ``ClientCallDetails`` that can be instantiated. + + ``grpc.ClientCallDetails`` is an abstract class with no ``__init__``, so we + need our own subclass to carry the fields through the interceptor chain. + """ + + def __init__( + self, + method: str, + timeout: float | None, + metadata: list[tuple[str, str | bytes]] | None, + credentials: grpc.CallCredentials | None, + wait_for_ready: bool | None, + compression: grpc.Compression | None, + ) -> None: + self.method = method + self.timeout = timeout + self.metadata = metadata + self.credentials = credentials + self.wait_for_ready = wait_for_ready + self.compression = compression + + +class _ApiKeyInterceptor( + grpc.UnaryUnaryClientInterceptor, # type: ignore[misc] + grpc.UnaryStreamClientInterceptor, # type: ignore[misc] +): + """Injects ``authorization: Bearer `` metadata into every RPC.""" + + def __init__(self, api_key: str) -> None: + self._metadata = (("authorization", f"Bearer {api_key}"),) + + def _inject( + self, client_call_details: grpc.ClientCallDetails + ) -> _ClientCallDetails: + metadata = list(client_call_details.metadata or []) + metadata.extend(self._metadata) + return _ClientCallDetails( + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + client_call_details.compression, + ) + + def intercept_unary_unary( + self, + continuation: Any, + client_call_details: grpc.ClientCallDetails, + request: Any, + ) -> Any: + return continuation(self._inject(client_call_details), request) + + def intercept_unary_stream( + self, + continuation: Any, + client_call_details: grpc.ClientCallDetails, + request: Any, + ) -> Any: + return continuation(self._inject(client_call_details), request) + + class Client: """Synchronous client for the Fila message broker. @@ -31,15 +97,73 @@ class Client: with Client("localhost:5555") as client: client.enqueue("my-queue", None, b"hello") + + TLS (system trust store):: + + client = Client("localhost:5555", tls=True) + + TLS (custom CA):: + + with open("ca.pem", "rb") as f: + ca = f.read() + client = Client("localhost:5555", ca_cert=ca) + + mTLS + API key:: + + client = Client( + "localhost:5555", + ca_cert=ca, + client_cert=cert, + client_key=key, + api_key="fila_...", + ) """ - def __init__(self, addr: str) -> None: + def __init__( + self, + addr: str, + *, + tls: bool = False, + ca_cert: bytes | None = None, + client_cert: bytes | None = None, + client_key: bytes | None = None, + api_key: str | None = None, + ) -> None: """Connect to a Fila broker at the given address. Args: addr: Broker address in "host:port" format (e.g., "localhost:5555"). + tls: Enable TLS using the OS system trust store for server + verification. Ignored when ``ca_cert`` is provided (which + implies TLS). Defaults to ``False``. + ca_cert: PEM-encoded CA certificate for verifying the server. + When provided, a TLS channel is used instead of an insecure one. + client_cert: PEM-encoded client certificate for mutual TLS (optional). + client_key: PEM-encoded client private key for mutual TLS (optional). + api_key: API key for authentication. When set, every RPC includes an + ``authorization: Bearer `` metadata header. """ - self._channel = grpc.insecure_channel(addr) + use_tls = tls or ca_cert is not None + + if (client_cert is not None or client_key is not None) and not use_tls: + raise ValueError( + "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" + ) + + if use_tls: + creds = grpc.ssl_channel_credentials( + root_certificates=ca_cert, + private_key=client_key, + certificate_chain=client_cert, + ) + self._channel = grpc.secure_channel(addr, creds) + else: + self._channel = grpc.insecure_channel(addr) + + if api_key is not None: + interceptor = _ApiKeyInterceptor(api_key) + self._channel = grpc.intercept_channel(self._channel, interceptor) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] def close(self) -> None: diff --git a/fila/v1/admin_pb2.py b/fila/v1/admin_pb2.py index 36d76ad..4bb4e27 100644 --- a/fila/v1/admin_pb2.py +++ b/fila/v1/admin_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13\x66ila/v1/admin.proto\x12\x07\x66ila.v1\"H\n\x12\x43reateQueueRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06\x63onfig\x18\x02 \x01(\x0b\x32\x14.fila.v1.QueueConfig\"b\n\x0bQueueConfig\x12\x19\n\x11on_enqueue_script\x18\x01 \x01(\t\x12\x19\n\x11on_failure_script\x18\x02 \x01(\t\x12\x1d\n\x15visibility_timeout_ms\x18\x03 \x01(\x04\"\'\n\x13\x43reateQueueResponse\x12\x10\n\x08queue_id\x18\x01 \x01(\t\"#\n\x12\x44\x65leteQueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"\x15\n\x13\x44\x65leteQueueResponse\".\n\x10SetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\x13\n\x11SetConfigResponse\"\x1f\n\x10GetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\"\"\n\x11GetConfigResponse\x12\r\n\x05value\x18\x01 \x01(\t\")\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"#\n\x11ListConfigRequest\x12\x0e\n\x06prefix\x18\x01 \x01(\t\"P\n\x12ListConfigResponse\x12%\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x14.fila.v1.ConfigEntry\x12\x13\n\x0btotal_count\x18\x02 \x01(\r\" \n\x0fGetStatsRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"b\n\x13PerFairnessKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x15\n\rpending_count\x18\x02 \x01(\x04\x12\x17\n\x0f\x63urrent_deficit\x18\x03 \x01(\x03\x12\x0e\n\x06weight\x18\x04 \x01(\r\"Z\n\x13PerThrottleKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0e\n\x06tokens\x18\x02 \x01(\x01\x12\x17\n\x0frate_per_second\x18\x03 \x01(\x01\x12\r\n\x05\x62urst\x18\x04 \x01(\x01\"\xec\x01\n\x10GetStatsResponse\x12\r\n\x05\x64\x65pth\x18\x01 \x01(\x04\x12\x11\n\tin_flight\x18\x02 \x01(\x04\x12\x1c\n\x14\x61\x63tive_fairness_keys\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x0f\n\x07quantum\x18\x05 \x01(\r\x12\x33\n\rper_key_stats\x18\x06 \x03(\x0b\x32\x1c.fila.v1.PerFairnessKeyStats\x12\x38\n\x12per_throttle_stats\x18\x07 \x03(\x0b\x32\x1c.fila.v1.PerThrottleKeyStats\"2\n\x0eRedriveRequest\x12\x11\n\tdlq_queue\x18\x01 \x01(\t\x12\r\n\x05\x63ount\x18\x02 \x01(\x04\"#\n\x0fRedriveResponse\x12\x10\n\x08redriven\x18\x01 \x01(\x04\"\x13\n\x11ListQueuesRequest\"U\n\tQueueInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05\x64\x65pth\x18\x02 \x01(\x04\x12\x11\n\tin_flight\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\"8\n\x12ListQueuesResponse\x12\"\n\x06queues\x18\x01 \x03(\x0b\x32\x12.fila.v1.QueueInfo2\xb4\x04\n\tFilaAdmin\x12H\n\x0b\x43reateQueue\x12\x1b.fila.v1.CreateQueueRequest\x1a\x1c.fila.v1.CreateQueueResponse\x12H\n\x0b\x44\x65leteQueue\x12\x1b.fila.v1.DeleteQueueRequest\x1a\x1c.fila.v1.DeleteQueueResponse\x12\x42\n\tSetConfig\x12\x19.fila.v1.SetConfigRequest\x1a\x1a.fila.v1.SetConfigResponse\x12\x42\n\tGetConfig\x12\x19.fila.v1.GetConfigRequest\x1a\x1a.fila.v1.GetConfigResponse\x12\x45\n\nListConfig\x12\x1a.fila.v1.ListConfigRequest\x1a\x1b.fila.v1.ListConfigResponse\x12?\n\x08GetStats\x12\x18.fila.v1.GetStatsRequest\x1a\x19.fila.v1.GetStatsResponse\x12<\n\x07Redrive\x12\x17.fila.v1.RedriveRequest\x1a\x18.fila.v1.RedriveResponse\x12\x45\n\nListQueues\x12\x1a.fila.v1.ListQueuesRequest\x1a\x1b.fila.v1.ListQueuesResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13\x66ila/v1/admin.proto\x12\x07\x66ila.v1\"H\n\x12\x43reateQueueRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06\x63onfig\x18\x02 \x01(\x0b\x32\x14.fila.v1.QueueConfig\"b\n\x0bQueueConfig\x12\x19\n\x11on_enqueue_script\x18\x01 \x01(\t\x12\x19\n\x11on_failure_script\x18\x02 \x01(\t\x12\x1d\n\x15visibility_timeout_ms\x18\x03 \x01(\x04\"\'\n\x13\x43reateQueueResponse\x12\x10\n\x08queue_id\x18\x01 \x01(\t\"#\n\x12\x44\x65leteQueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"\x15\n\x13\x44\x65leteQueueResponse\".\n\x10SetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\x13\n\x11SetConfigResponse\"\x1f\n\x10GetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\"\"\n\x11GetConfigResponse\x12\r\n\x05value\x18\x01 \x01(\t\")\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"#\n\x11ListConfigRequest\x12\x0e\n\x06prefix\x18\x01 \x01(\t\"P\n\x12ListConfigResponse\x12%\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x14.fila.v1.ConfigEntry\x12\x13\n\x0btotal_count\x18\x02 \x01(\r\" \n\x0fGetStatsRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"b\n\x13PerFairnessKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x15\n\rpending_count\x18\x02 \x01(\x04\x12\x17\n\x0f\x63urrent_deficit\x18\x03 \x01(\x03\x12\x0e\n\x06weight\x18\x04 \x01(\r\"Z\n\x13PerThrottleKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0e\n\x06tokens\x18\x02 \x01(\x01\x12\x17\n\x0frate_per_second\x18\x03 \x01(\x01\x12\r\n\x05\x62urst\x18\x04 \x01(\x01\"\x9f\x02\n\x10GetStatsResponse\x12\r\n\x05\x64\x65pth\x18\x01 \x01(\x04\x12\x11\n\tin_flight\x18\x02 \x01(\x04\x12\x1c\n\x14\x61\x63tive_fairness_keys\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x0f\n\x07quantum\x18\x05 \x01(\r\x12\x33\n\rper_key_stats\x18\x06 \x03(\x0b\x32\x1c.fila.v1.PerFairnessKeyStats\x12\x38\n\x12per_throttle_stats\x18\x07 \x03(\x0b\x32\x1c.fila.v1.PerThrottleKeyStats\x12\x16\n\x0eleader_node_id\x18\x08 \x01(\x04\x12\x19\n\x11replication_count\x18\t \x01(\r\"2\n\x0eRedriveRequest\x12\x11\n\tdlq_queue\x18\x01 \x01(\t\x12\r\n\x05\x63ount\x18\x02 \x01(\x04\"#\n\x0fRedriveResponse\x12\x10\n\x08redriven\x18\x01 \x01(\x04\"\x13\n\x11ListQueuesRequest\"m\n\tQueueInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05\x64\x65pth\x18\x02 \x01(\x04\x12\x11\n\tin_flight\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x16\n\x0eleader_node_id\x18\x05 \x01(\x04\"T\n\x12ListQueuesResponse\x12\"\n\x06queues\x18\x01 \x03(\x0b\x32\x12.fila.v1.QueueInfo\x12\x1a\n\x12\x63luster_node_count\x18\x02 \x01(\r\"Q\n\x13\x43reateApiKeyRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\rexpires_at_ms\x18\x02 \x01(\x04\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\"J\n\x14\x43reateApiKeyResponse\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\"%\n\x13RevokeApiKeyRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\"\x16\n\x14RevokeApiKeyResponse\"\x14\n\x12ListApiKeysRequest\"o\n\nApiKeyInfo\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x15\n\rcreated_at_ms\x18\x03 \x01(\x04\x12\x15\n\rexpires_at_ms\x18\x04 \x01(\x04\x12\x15\n\ris_superadmin\x18\x05 \x01(\x08\"8\n\x13ListApiKeysResponse\x12!\n\x04keys\x18\x01 \x03(\x0b\x32\x13.fila.v1.ApiKeyInfo\".\n\rAclPermission\x12\x0c\n\x04kind\x18\x01 \x01(\t\x12\x0f\n\x07pattern\x18\x02 \x01(\t\"L\n\rSetAclRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12+\n\x0bpermissions\x18\x02 \x03(\x0b\x32\x16.fila.v1.AclPermission\"\x10\n\x0eSetAclResponse\"\x1f\n\rGetAclRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\"d\n\x0eGetAclResponse\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12+\n\x0bpermissions\x18\x02 \x03(\x0b\x32\x16.fila.v1.AclPermission\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\x32\x8e\x07\n\tFilaAdmin\x12H\n\x0b\x43reateQueue\x12\x1b.fila.v1.CreateQueueRequest\x1a\x1c.fila.v1.CreateQueueResponse\x12H\n\x0b\x44\x65leteQueue\x12\x1b.fila.v1.DeleteQueueRequest\x1a\x1c.fila.v1.DeleteQueueResponse\x12\x42\n\tSetConfig\x12\x19.fila.v1.SetConfigRequest\x1a\x1a.fila.v1.SetConfigResponse\x12\x42\n\tGetConfig\x12\x19.fila.v1.GetConfigRequest\x1a\x1a.fila.v1.GetConfigResponse\x12\x45\n\nListConfig\x12\x1a.fila.v1.ListConfigRequest\x1a\x1b.fila.v1.ListConfigResponse\x12?\n\x08GetStats\x12\x18.fila.v1.GetStatsRequest\x1a\x19.fila.v1.GetStatsResponse\x12<\n\x07Redrive\x12\x17.fila.v1.RedriveRequest\x1a\x18.fila.v1.RedriveResponse\x12\x45\n\nListQueues\x12\x1a.fila.v1.ListQueuesRequest\x1a\x1b.fila.v1.ListQueuesResponse\x12K\n\x0c\x43reateApiKey\x12\x1c.fila.v1.CreateApiKeyRequest\x1a\x1d.fila.v1.CreateApiKeyResponse\x12K\n\x0cRevokeApiKey\x12\x1c.fila.v1.RevokeApiKeyRequest\x1a\x1d.fila.v1.RevokeApiKeyResponse\x12H\n\x0bListApiKeys\x12\x1b.fila.v1.ListApiKeysRequest\x1a\x1c.fila.v1.ListApiKeysResponse\x12\x39\n\x06SetAcl\x12\x16.fila.v1.SetAclRequest\x1a\x17.fila.v1.SetAclResponse\x12\x39\n\x06GetAcl\x12\x16.fila.v1.GetAclRequest\x1a\x17.fila.v1.GetAclResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -62,17 +62,41 @@ _globals['_PERTHROTTLEKEYSTATS']._serialized_start=741 _globals['_PERTHROTTLEKEYSTATS']._serialized_end=831 _globals['_GETSTATSRESPONSE']._serialized_start=834 - _globals['_GETSTATSRESPONSE']._serialized_end=1070 - _globals['_REDRIVEREQUEST']._serialized_start=1072 - _globals['_REDRIVEREQUEST']._serialized_end=1122 - _globals['_REDRIVERESPONSE']._serialized_start=1124 - _globals['_REDRIVERESPONSE']._serialized_end=1159 - _globals['_LISTQUEUESREQUEST']._serialized_start=1161 - _globals['_LISTQUEUESREQUEST']._serialized_end=1180 - _globals['_QUEUEINFO']._serialized_start=1182 - _globals['_QUEUEINFO']._serialized_end=1267 - _globals['_LISTQUEUESRESPONSE']._serialized_start=1269 - _globals['_LISTQUEUESRESPONSE']._serialized_end=1325 - _globals['_FILAADMIN']._serialized_start=1328 - _globals['_FILAADMIN']._serialized_end=1892 + _globals['_GETSTATSRESPONSE']._serialized_end=1121 + _globals['_REDRIVEREQUEST']._serialized_start=1123 + _globals['_REDRIVEREQUEST']._serialized_end=1173 + _globals['_REDRIVERESPONSE']._serialized_start=1175 + _globals['_REDRIVERESPONSE']._serialized_end=1210 + _globals['_LISTQUEUESREQUEST']._serialized_start=1212 + _globals['_LISTQUEUESREQUEST']._serialized_end=1231 + _globals['_QUEUEINFO']._serialized_start=1233 + _globals['_QUEUEINFO']._serialized_end=1342 + _globals['_LISTQUEUESRESPONSE']._serialized_start=1344 + _globals['_LISTQUEUESRESPONSE']._serialized_end=1428 + _globals['_CREATEAPIKEYREQUEST']._serialized_start=1430 + _globals['_CREATEAPIKEYREQUEST']._serialized_end=1511 + _globals['_CREATEAPIKEYRESPONSE']._serialized_start=1513 + _globals['_CREATEAPIKEYRESPONSE']._serialized_end=1587 + _globals['_REVOKEAPIKEYREQUEST']._serialized_start=1589 + _globals['_REVOKEAPIKEYREQUEST']._serialized_end=1626 + _globals['_REVOKEAPIKEYRESPONSE']._serialized_start=1628 + _globals['_REVOKEAPIKEYRESPONSE']._serialized_end=1650 + _globals['_LISTAPIKEYSREQUEST']._serialized_start=1652 + _globals['_LISTAPIKEYSREQUEST']._serialized_end=1672 + _globals['_APIKEYINFO']._serialized_start=1674 + _globals['_APIKEYINFO']._serialized_end=1785 + _globals['_LISTAPIKEYSRESPONSE']._serialized_start=1787 + _globals['_LISTAPIKEYSRESPONSE']._serialized_end=1843 + _globals['_ACLPERMISSION']._serialized_start=1845 + _globals['_ACLPERMISSION']._serialized_end=1891 + _globals['_SETACLREQUEST']._serialized_start=1893 + _globals['_SETACLREQUEST']._serialized_end=1969 + _globals['_SETACLRESPONSE']._serialized_start=1971 + _globals['_SETACLRESPONSE']._serialized_end=1987 + _globals['_GETACLREQUEST']._serialized_start=1989 + _globals['_GETACLREQUEST']._serialized_end=2020 + _globals['_GETACLRESPONSE']._serialized_start=2022 + _globals['_GETACLRESPONSE']._serialized_end=2122 + _globals['_FILAADMIN']._serialized_start=2125 + _globals['_FILAADMIN']._serialized_end=3035 # @@protoc_insertion_point(module_scope) diff --git a/fila/v1/admin_pb2.pyi b/fila/v1/admin_pb2.pyi index fe54055..d603b29 100644 --- a/fila/v1/admin_pb2.pyi +++ b/fila/v1/admin_pb2.pyi @@ -117,7 +117,7 @@ class PerThrottleKeyStats(_message.Message): def __init__(self, key: _Optional[str] = ..., tokens: _Optional[float] = ..., rate_per_second: _Optional[float] = ..., burst: _Optional[float] = ...) -> None: ... class GetStatsResponse(_message.Message): - __slots__ = ("depth", "in_flight", "active_fairness_keys", "active_consumers", "quantum", "per_key_stats", "per_throttle_stats") + __slots__ = ("depth", "in_flight", "active_fairness_keys", "active_consumers", "quantum", "per_key_stats", "per_throttle_stats", "leader_node_id", "replication_count") DEPTH_FIELD_NUMBER: _ClassVar[int] IN_FLIGHT_FIELD_NUMBER: _ClassVar[int] ACTIVE_FAIRNESS_KEYS_FIELD_NUMBER: _ClassVar[int] @@ -125,6 +125,8 @@ class GetStatsResponse(_message.Message): QUANTUM_FIELD_NUMBER: _ClassVar[int] PER_KEY_STATS_FIELD_NUMBER: _ClassVar[int] PER_THROTTLE_STATS_FIELD_NUMBER: _ClassVar[int] + LEADER_NODE_ID_FIELD_NUMBER: _ClassVar[int] + REPLICATION_COUNT_FIELD_NUMBER: _ClassVar[int] depth: int in_flight: int active_fairness_keys: int @@ -132,7 +134,9 @@ class GetStatsResponse(_message.Message): quantum: int per_key_stats: _containers.RepeatedCompositeFieldContainer[PerFairnessKeyStats] per_throttle_stats: _containers.RepeatedCompositeFieldContainer[PerThrottleKeyStats] - def __init__(self, depth: _Optional[int] = ..., in_flight: _Optional[int] = ..., active_fairness_keys: _Optional[int] = ..., active_consumers: _Optional[int] = ..., quantum: _Optional[int] = ..., per_key_stats: _Optional[_Iterable[_Union[PerFairnessKeyStats, _Mapping]]] = ..., per_throttle_stats: _Optional[_Iterable[_Union[PerThrottleKeyStats, _Mapping]]] = ...) -> None: ... + leader_node_id: int + replication_count: int + def __init__(self, depth: _Optional[int] = ..., in_flight: _Optional[int] = ..., active_fairness_keys: _Optional[int] = ..., active_consumers: _Optional[int] = ..., quantum: _Optional[int] = ..., per_key_stats: _Optional[_Iterable[_Union[PerFairnessKeyStats, _Mapping]]] = ..., per_throttle_stats: _Optional[_Iterable[_Union[PerThrottleKeyStats, _Mapping]]] = ..., leader_node_id: _Optional[int] = ..., replication_count: _Optional[int] = ...) -> None: ... class RedriveRequest(_message.Message): __slots__ = ("dlq_queue", "count") @@ -153,19 +157,113 @@ class ListQueuesRequest(_message.Message): def __init__(self) -> None: ... class QueueInfo(_message.Message): - __slots__ = ("name", "depth", "in_flight", "active_consumers") + __slots__ = ("name", "depth", "in_flight", "active_consumers", "leader_node_id") NAME_FIELD_NUMBER: _ClassVar[int] DEPTH_FIELD_NUMBER: _ClassVar[int] IN_FLIGHT_FIELD_NUMBER: _ClassVar[int] ACTIVE_CONSUMERS_FIELD_NUMBER: _ClassVar[int] + LEADER_NODE_ID_FIELD_NUMBER: _ClassVar[int] name: str depth: int in_flight: int active_consumers: int - def __init__(self, name: _Optional[str] = ..., depth: _Optional[int] = ..., in_flight: _Optional[int] = ..., active_consumers: _Optional[int] = ...) -> None: ... + leader_node_id: int + def __init__(self, name: _Optional[str] = ..., depth: _Optional[int] = ..., in_flight: _Optional[int] = ..., active_consumers: _Optional[int] = ..., leader_node_id: _Optional[int] = ...) -> None: ... class ListQueuesResponse(_message.Message): - __slots__ = ("queues",) + __slots__ = ("queues", "cluster_node_count") QUEUES_FIELD_NUMBER: _ClassVar[int] + CLUSTER_NODE_COUNT_FIELD_NUMBER: _ClassVar[int] queues: _containers.RepeatedCompositeFieldContainer[QueueInfo] - def __init__(self, queues: _Optional[_Iterable[_Union[QueueInfo, _Mapping]]] = ...) -> None: ... + cluster_node_count: int + def __init__(self, queues: _Optional[_Iterable[_Union[QueueInfo, _Mapping]]] = ..., cluster_node_count: _Optional[int] = ...) -> None: ... + +class CreateApiKeyRequest(_message.Message): + __slots__ = ("name", "expires_at_ms", "is_superadmin") + NAME_FIELD_NUMBER: _ClassVar[int] + EXPIRES_AT_MS_FIELD_NUMBER: _ClassVar[int] + IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] + name: str + expires_at_ms: int + is_superadmin: bool + def __init__(self, name: _Optional[str] = ..., expires_at_ms: _Optional[int] = ..., is_superadmin: bool = ...) -> None: ... + +class CreateApiKeyResponse(_message.Message): + __slots__ = ("key_id", "key", "is_superadmin") + KEY_ID_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] + key_id: str + key: str + is_superadmin: bool + def __init__(self, key_id: _Optional[str] = ..., key: _Optional[str] = ..., is_superadmin: bool = ...) -> None: ... + +class RevokeApiKeyRequest(_message.Message): + __slots__ = ("key_id",) + KEY_ID_FIELD_NUMBER: _ClassVar[int] + key_id: str + def __init__(self, key_id: _Optional[str] = ...) -> None: ... + +class RevokeApiKeyResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class ListApiKeysRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class ApiKeyInfo(_message.Message): + __slots__ = ("key_id", "name", "created_at_ms", "expires_at_ms", "is_superadmin") + KEY_ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + CREATED_AT_MS_FIELD_NUMBER: _ClassVar[int] + EXPIRES_AT_MS_FIELD_NUMBER: _ClassVar[int] + IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] + key_id: str + name: str + created_at_ms: int + expires_at_ms: int + is_superadmin: bool + def __init__(self, key_id: _Optional[str] = ..., name: _Optional[str] = ..., created_at_ms: _Optional[int] = ..., expires_at_ms: _Optional[int] = ..., is_superadmin: bool = ...) -> None: ... + +class ListApiKeysResponse(_message.Message): + __slots__ = ("keys",) + KEYS_FIELD_NUMBER: _ClassVar[int] + keys: _containers.RepeatedCompositeFieldContainer[ApiKeyInfo] + def __init__(self, keys: _Optional[_Iterable[_Union[ApiKeyInfo, _Mapping]]] = ...) -> None: ... + +class AclPermission(_message.Message): + __slots__ = ("kind", "pattern") + KIND_FIELD_NUMBER: _ClassVar[int] + PATTERN_FIELD_NUMBER: _ClassVar[int] + kind: str + pattern: str + def __init__(self, kind: _Optional[str] = ..., pattern: _Optional[str] = ...) -> None: ... + +class SetAclRequest(_message.Message): + __slots__ = ("key_id", "permissions") + KEY_ID_FIELD_NUMBER: _ClassVar[int] + PERMISSIONS_FIELD_NUMBER: _ClassVar[int] + key_id: str + permissions: _containers.RepeatedCompositeFieldContainer[AclPermission] + def __init__(self, key_id: _Optional[str] = ..., permissions: _Optional[_Iterable[_Union[AclPermission, _Mapping]]] = ...) -> None: ... + +class SetAclResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class GetAclRequest(_message.Message): + __slots__ = ("key_id",) + KEY_ID_FIELD_NUMBER: _ClassVar[int] + key_id: str + def __init__(self, key_id: _Optional[str] = ...) -> None: ... + +class GetAclResponse(_message.Message): + __slots__ = ("key_id", "permissions", "is_superadmin") + KEY_ID_FIELD_NUMBER: _ClassVar[int] + PERMISSIONS_FIELD_NUMBER: _ClassVar[int] + IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] + key_id: str + permissions: _containers.RepeatedCompositeFieldContainer[AclPermission] + is_superadmin: bool + def __init__(self, key_id: _Optional[str] = ..., permissions: _Optional[_Iterable[_Union[AclPermission, _Mapping]]] = ..., is_superadmin: bool = ...) -> None: ... diff --git a/fila/v1/admin_pb2_grpc.py b/fila/v1/admin_pb2_grpc.py index 3b07e1a..93d6c4e 100644 --- a/fila/v1/admin_pb2_grpc.py +++ b/fila/v1/admin_pb2_grpc.py @@ -75,6 +75,31 @@ def __init__(self, channel): request_serializer=fila_dot_v1_dot_admin__pb2.ListQueuesRequest.SerializeToString, response_deserializer=fila_dot_v1_dot_admin__pb2.ListQueuesResponse.FromString, _registered_method=True) + self.CreateApiKey = channel.unary_unary( + '/fila.v1.FilaAdmin/CreateApiKey', + request_serializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.SerializeToString, + response_deserializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.FromString, + _registered_method=True) + self.RevokeApiKey = channel.unary_unary( + '/fila.v1.FilaAdmin/RevokeApiKey', + request_serializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.SerializeToString, + response_deserializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.FromString, + _registered_method=True) + self.ListApiKeys = channel.unary_unary( + '/fila.v1.FilaAdmin/ListApiKeys', + request_serializer=fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.SerializeToString, + response_deserializer=fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.FromString, + _registered_method=True) + self.SetAcl = channel.unary_unary( + '/fila.v1.FilaAdmin/SetAcl', + request_serializer=fila_dot_v1_dot_admin__pb2.SetAclRequest.SerializeToString, + response_deserializer=fila_dot_v1_dot_admin__pb2.SetAclResponse.FromString, + _registered_method=True) + self.GetAcl = channel.unary_unary( + '/fila.v1.FilaAdmin/GetAcl', + request_serializer=fila_dot_v1_dot_admin__pb2.GetAclRequest.SerializeToString, + response_deserializer=fila_dot_v1_dot_admin__pb2.GetAclResponse.FromString, + _registered_method=True) class FilaAdminServicer(object): @@ -129,6 +154,38 @@ def ListQueues(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def CreateApiKey(self, request, context): + """API key management. CreateApiKey bypasses auth (bootstrap); others require a valid key. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RevokeApiKey(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ListApiKeys(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SetAcl(self, request, context): + """Per-key ACL management. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetAcl(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_FilaAdminServicer_to_server(servicer, server): rpc_method_handlers = { @@ -172,6 +229,31 @@ def add_FilaAdminServicer_to_server(servicer, server): request_deserializer=fila_dot_v1_dot_admin__pb2.ListQueuesRequest.FromString, response_serializer=fila_dot_v1_dot_admin__pb2.ListQueuesResponse.SerializeToString, ), + 'CreateApiKey': grpc.unary_unary_rpc_method_handler( + servicer.CreateApiKey, + request_deserializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.FromString, + response_serializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.SerializeToString, + ), + 'RevokeApiKey': grpc.unary_unary_rpc_method_handler( + servicer.RevokeApiKey, + request_deserializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.FromString, + response_serializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.SerializeToString, + ), + 'ListApiKeys': grpc.unary_unary_rpc_method_handler( + servicer.ListApiKeys, + request_deserializer=fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.FromString, + response_serializer=fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.SerializeToString, + ), + 'SetAcl': grpc.unary_unary_rpc_method_handler( + servicer.SetAcl, + request_deserializer=fila_dot_v1_dot_admin__pb2.SetAclRequest.FromString, + response_serializer=fila_dot_v1_dot_admin__pb2.SetAclResponse.SerializeToString, + ), + 'GetAcl': grpc.unary_unary_rpc_method_handler( + servicer.GetAcl, + request_deserializer=fila_dot_v1_dot_admin__pb2.GetAclRequest.FromString, + response_serializer=fila_dot_v1_dot_admin__pb2.GetAclResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'fila.v1.FilaAdmin', rpc_method_handlers) @@ -399,3 +481,138 @@ def ListQueues(request, timeout, metadata, _registered_method=True) + + @staticmethod + def CreateApiKey(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fila.v1.FilaAdmin/CreateApiKey', + fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.SerializeToString, + fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def RevokeApiKey(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fila.v1.FilaAdmin/RevokeApiKey', + fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.SerializeToString, + fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ListApiKeys(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fila.v1.FilaAdmin/ListApiKeys', + fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.SerializeToString, + fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SetAcl(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fila.v1.FilaAdmin/SetAcl', + fila_dot_v1_dot_admin__pb2.SetAclRequest.SerializeToString, + fila_dot_v1_dot_admin__pb2.SetAclResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetAcl(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fila.v1.FilaAdmin/GetAcl', + fila_dot_v1_dot_admin__pb2.GetAclRequest.SerializeToString, + fila_dot_v1_dot_admin__pb2.GetAclResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/proto/fila/v1/admin.proto b/proto/fila/v1/admin.proto index bf9d5ca..886e58d 100644 --- a/proto/fila/v1/admin.proto +++ b/proto/fila/v1/admin.proto @@ -11,6 +11,15 @@ service FilaAdmin { rpc GetStats(GetStatsRequest) returns (GetStatsResponse); rpc Redrive(RedriveRequest) returns (RedriveResponse); rpc ListQueues(ListQueuesRequest) returns (ListQueuesResponse); + + // API key management. CreateApiKey bypasses auth (bootstrap); others require a valid key. + rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse); + rpc RevokeApiKey(RevokeApiKeyRequest) returns (RevokeApiKeyResponse); + rpc ListApiKeys(ListApiKeysRequest) returns (ListApiKeysResponse); + + // Per-key ACL management. + rpc SetAcl(SetAclRequest) returns (SetAclResponse); + rpc GetAcl(GetAclRequest) returns (GetAclResponse); } message CreateQueueRequest { @@ -89,6 +98,9 @@ message GetStatsResponse { uint32 quantum = 5; repeated PerFairnessKeyStats per_key_stats = 6; repeated PerThrottleKeyStats per_throttle_stats = 7; + // Cluster fields (0 when not in cluster mode). + uint64 leader_node_id = 8; + uint32 replication_count = 9; } message RedriveRequest { @@ -107,8 +119,79 @@ message QueueInfo { uint64 depth = 2; uint64 in_flight = 3; uint32 active_consumers = 4; + uint64 leader_node_id = 5; } message ListQueuesResponse { repeated QueueInfo queues = 1; + uint32 cluster_node_count = 2; +} + +// --- API Key Management --- + +message CreateApiKeyRequest { + /// Human-readable label for the key. + string name = 1; + /// Optional Unix timestamp (milliseconds) after which the key expires. + /// 0 means no expiration. + uint64 expires_at_ms = 2; + /// When true, the key bypasses all ACL checks (superadmin). + bool is_superadmin = 3; +} + +message CreateApiKeyResponse { + /// Opaque key ID for management operations (revoke, list, set-acl). + string key_id = 1; + /// Plaintext API key. Returned once — store it securely. + string key = 2; + /// Whether this key has superadmin privileges. + bool is_superadmin = 3; +} + +message RevokeApiKeyRequest { + string key_id = 1; +} + +message RevokeApiKeyResponse {} + +message ListApiKeysRequest {} + +message ApiKeyInfo { + string key_id = 1; + string name = 2; + uint64 created_at_ms = 3; + /// 0 means no expiration. + uint64 expires_at_ms = 4; + bool is_superadmin = 5; +} + +message ListApiKeysResponse { + repeated ApiKeyInfo keys = 1; +} + +// --- ACL Management --- + +/// A single permission grant: kind (produce/consume/admin) + queue pattern. +message AclPermission { + /// One of: "produce", "consume", "admin". + string kind = 1; + /// Queue name or wildcard ("*" or "orders.*"). + string pattern = 2; +} + +message SetAclRequest { + string key_id = 1; + repeated AclPermission permissions = 2; +} + +message SetAclResponse {} + +message GetAclRequest { + string key_id = 1; +} + +message GetAclResponse { + string key_id = 1; + repeated AclPermission permissions = 2; + bool is_superadmin = 3; } diff --git a/tests/conftest.py b/tests/conftest.py index 84e98da..3b91d60 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from __future__ import annotations +import ipaddress import os import shutil import socket @@ -37,13 +38,130 @@ def _find_free_port() -> int: return s.getsockname()[1] +def _generate_self_signed_certs(out_dir: str) -> dict[str, str]: + """Generate a self-signed CA + server + client cert for testing. + + Returns a dict with keys: ca_cert, server_cert, server_key, client_cert, client_key. + """ + import datetime + + try: + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError: + pytest.skip("cryptography package required for TLS tests") + + # CA key + cert + ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Fila Test CA")]) + ca_cert = ( + x509.CertificateBuilder() + .subject_name(ca_name) + .issuer_name(ca_name) + .public_key(ca_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after( + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1) + ) + .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) + .sign(ca_key, hashes.SHA256()) + ) + + # Server key + cert + server_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + server_cert = ( + x509.CertificateBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")])) + .issuer_name(ca_name) + .public_key(server_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after( + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1) + ) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(ca_key, hashes.SHA256()) + ) + + # Client key + cert + client_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + client_cert = ( + x509.CertificateBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test-client")])) + .issuer_name(ca_name) + .public_key(client_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after( + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1) + ) + .sign(ca_key, hashes.SHA256()) + ) + + def _write_pem(path: str, data: bytes) -> str: + with open(path, "wb") as f: + f.write(data) + return path + + paths = { + "ca_cert": _write_pem( + os.path.join(out_dir, "ca.pem"), + ca_cert.public_bytes(serialization.Encoding.PEM), + ), + "server_cert": _write_pem( + os.path.join(out_dir, "server.pem"), + server_cert.public_bytes(serialization.Encoding.PEM), + ), + "server_key": _write_pem( + os.path.join(out_dir, "server-key.pem"), + server_key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ), + ), + "client_cert": _write_pem( + os.path.join(out_dir, "client.pem"), + client_cert.public_bytes(serialization.Encoding.PEM), + ), + "client_key": _write_pem( + os.path.join(out_dir, "client-key.pem"), + client_key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ), + ), + } + return paths + + class TestServer: """Manages a fila-server subprocess for integration tests.""" - def __init__(self, addr: str, process: subprocess.Popen[bytes], data_dir: str) -> None: + def __init__( + self, + addr: str, + process: subprocess.Popen[bytes], + data_dir: str, + *, + tls_paths: dict[str, str] | None = None, + api_key: str | None = None, + ) -> None: self.addr = addr self._process = process self._data_dir = data_dir + self.tls_paths = tls_paths + self.api_key = api_key def stop(self) -> None: """Kill the server and clean up.""" @@ -51,9 +169,33 @@ def stop(self) -> None: self._process.wait() shutil.rmtree(self._data_dir, ignore_errors=True) + def _make_channel(self) -> grpc.Channel: + """Create a gRPC channel to this server (TLS-aware).""" + if self.tls_paths is not None: + with open(self.tls_paths["ca_cert"], "rb") as f: + ca = f.read() + with open(self.tls_paths["client_cert"], "rb") as f: + cert = f.read() + with open(self.tls_paths["client_key"], "rb") as f: + key = f.read() + creds = grpc.ssl_channel_credentials( + root_certificates=ca, + private_key=key, + certificate_chain=cert, + ) + channel = grpc.secure_channel(self.addr, creds) + else: + channel = grpc.insecure_channel(self.addr) + + if self.api_key is not None: + from fila.client import _ApiKeyInterceptor + channel = grpc.intercept_channel(channel, _ApiKeyInterceptor(self.api_key)) + + return channel + def create_queue(self, name: str) -> None: """Create a queue on the test server via admin gRPC.""" - channel = grpc.insecure_channel(self.addr) + channel = self._make_channel() stub = admin_pb2_grpc.FilaAdminStub(channel) stub.CreateQueue( admin_pb2.CreateQueueRequest( @@ -94,13 +236,14 @@ def server() -> Generator[TestServer, None, None]: # Wait for server to be ready. deadline = time.monotonic() + 10.0 while time.monotonic() < deadline: + channel = grpc.insecure_channel(addr) try: - channel = grpc.insecure_channel(addr) stub = admin_pb2_grpc.FilaAdminStub(channel) stub.ListQueues(admin_pb2.ListQueuesRequest()) channel.close() break except grpc.RpcError: + channel.close() time.sleep(0.05) else: ts.stop() @@ -109,3 +252,118 @@ def server() -> Generator[TestServer, None, None]: yield ts ts.stop() + + +@pytest.fixture() +def tls_server() -> Generator[TestServer, None, None]: + """Start a TLS-enabled fila-server, yield it, then shut down.""" + if not FILA_SERVER_AVAILABLE: + pytest.skip(f"fila-server binary not found at {FILA_SERVER_BIN}") + + try: + import cryptography # noqa: F401 + except ImportError: + pytest.skip("cryptography package required for TLS tests") + + port = _find_free_port() + addr = f"127.0.0.1:{port}" + + data_dir = tempfile.mkdtemp(prefix="fila-tls-test-") + tls_paths = _generate_self_signed_certs(data_dir) + + # Write config with TLS enabled. + config_path = os.path.join(data_dir, "fila.toml") + with open(config_path, "w") as f: + f.write( + f'[server]\n' + f'listen_addr = "{addr}"\n' + f'\n' + f'[tls]\n' + f'ca_cert = "{tls_paths["ca_cert"]}"\n' + f'server_cert = "{tls_paths["server_cert"]}"\n' + f'server_key = "{tls_paths["server_key"]}"\n' + ) + + env = {**os.environ, "FILA_DATA_DIR": os.path.join(data_dir, "db")} + process = subprocess.Popen( + [FILA_SERVER_BIN], + cwd=data_dir, + env=env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + ts = TestServer(addr, process, data_dir, tls_paths=tls_paths) + + # Wait for server to be ready (use TLS channel). + deadline = time.monotonic() + 10.0 + while time.monotonic() < deadline: + channel = ts._make_channel() + try: + stub = admin_pb2_grpc.FilaAdminStub(channel) + stub.ListQueues(admin_pb2.ListQueuesRequest()) + channel.close() + break + except grpc.RpcError: + channel.close() + time.sleep(0.05) + else: + ts.stop() + pytest.fail("TLS fila-server did not become ready within 10s") + + yield ts + + ts.stop() + + +@pytest.fixture() +def auth_server() -> Generator[TestServer, None, None]: + """Start a fila-server with API key auth enabled, yield it, then shut down.""" + if not FILA_SERVER_AVAILABLE: + pytest.skip(f"fila-server binary not found at {FILA_SERVER_BIN}") + + port = _find_free_port() + addr = f"127.0.0.1:{port}" + bootstrap_key = "test-bootstrap-key-for-integration" + + data_dir = tempfile.mkdtemp(prefix="fila-auth-test-") + + # Write config with bootstrap API key. + config_path = os.path.join(data_dir, "fila.toml") + with open(config_path, "w") as f: + f.write( + f'[server]\n' + f'listen_addr = "{addr}"\n' + f'bootstrap_apikey = "{bootstrap_key}"\n' + ) + + env = {**os.environ, "FILA_DATA_DIR": os.path.join(data_dir, "db")} + process = subprocess.Popen( + [FILA_SERVER_BIN], + cwd=data_dir, + env=env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + ts = TestServer(addr, process, data_dir, api_key=bootstrap_key) + + # Wait for server to be ready. + deadline = time.monotonic() + 10.0 + while time.monotonic() < deadline: + channel = ts._make_channel() + try: + stub = admin_pb2_grpc.FilaAdminStub(channel) + stub.ListQueues(admin_pb2.ListQueuesRequest()) + channel.close() + break + except grpc.RpcError: + channel.close() + time.sleep(0.05) + else: + ts.stop() + pytest.fail("auth fila-server did not become ready within 10s") + + yield ts + + ts.stop() diff --git a/tests/test_client.py b/tests/test_client.py index 6096e73..b8e353e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -102,3 +102,143 @@ async def test_async_enqueue_consume_ack(self, server: object) -> None: # Ack the message. await client.ack("test-async-eca", msg.id) + + +class TestTlsClient: + """Integration tests for TLS connections.""" + + def test_tls_enqueue_consume_ack(self, tls_server: object) -> None: + """Full lifecycle over TLS: enqueue -> consume -> ack.""" + from tests.conftest import TestServer + + assert isinstance(tls_server, TestServer) + assert tls_server.tls_paths is not None + + tls_server.create_queue("test-tls") + + with open(tls_server.tls_paths["ca_cert"], "rb") as f: + ca_cert = f.read() + with open(tls_server.tls_paths["client_cert"], "rb") as f: + client_cert = f.read() + with open(tls_server.tls_paths["client_key"], "rb") as f: + client_key = f.read() + + with fila.Client( + tls_server.addr, + ca_cert=ca_cert, + client_cert=client_cert, + client_key=client_key, + ) as client: + msg_id = client.enqueue("test-tls", {"secure": "true"}, b"tls payload") + assert msg_id != "" + + stream = client.consume("test-tls") + msg = next(stream) + + assert msg.id == msg_id + assert msg.payload == b"tls payload" + + client.ack("test-tls", msg.id) + + @pytest.mark.asyncio + async def test_async_tls_enqueue_consume_ack(self, tls_server: object) -> None: + """Full async lifecycle over TLS.""" + from tests.conftest import TestServer + + assert isinstance(tls_server, TestServer) + assert tls_server.tls_paths is not None + + tls_server.create_queue("test-async-tls") + + with open(tls_server.tls_paths["ca_cert"], "rb") as f: + ca_cert = f.read() + with open(tls_server.tls_paths["client_cert"], "rb") as f: + client_cert = f.read() + with open(tls_server.tls_paths["client_key"], "rb") as f: + client_key = f.read() + + async with fila.AsyncClient( + tls_server.addr, + ca_cert=ca_cert, + client_cert=client_cert, + client_key=client_key, + ) as client: + msg_id = await client.enqueue("test-async-tls", None, b"async tls") + assert msg_id != "" + + stream = await client.consume("test-async-tls") + msg = await stream.__anext__() + + assert msg.id == msg_id + assert msg.payload == b"async tls" + + await client.ack("test-async-tls", msg.id) + + +class TestApiKeyAuth: + """Integration tests for API key authentication.""" + + def test_api_key_enqueue_consume_ack(self, auth_server: object) -> None: + """Full lifecycle with API key auth: enqueue -> consume -> ack.""" + from tests.conftest import TestServer + + assert isinstance(auth_server, TestServer) + assert auth_server.api_key is not None + + auth_server.create_queue("test-auth") + + with fila.Client(auth_server.addr, api_key=auth_server.api_key) as client: + msg_id = client.enqueue("test-auth", None, b"authenticated") + assert msg_id != "" + + stream = client.consume("test-auth") + msg = next(stream) + + assert msg.id == msg_id + assert msg.payload == b"authenticated" + + client.ack("test-auth", msg.id) + + def test_missing_api_key_rejected(self, auth_server: object) -> None: + """Requests without API key are rejected when auth is enabled.""" + import grpc + + from tests.conftest import TestServer + + assert isinstance(auth_server, TestServer) + + # Probe whether the server actually enforces API key auth. + # The dev-latest binary may predate the bootstrap_apikey feature, + # in which case unauthenticated requests succeed rather than fail. + with fila.Client(auth_server.addr) as probe: + try: + probe.enqueue("__auth_probe__", None, b"probe") + except fila.RPCError as e: + if e.code != grpc.StatusCode.UNAUTHENTICATED: + pytest.fail(f"unexpected RPC error during auth probe: {e.code}") + except fila.QueueNotFoundError: + pytest.skip("server does not enforce API key auth") + else: + pytest.skip("server does not enforce API key auth") + + # If we reach here, the server enforces auth. + with fila.Client(auth_server.addr) as client: + with pytest.raises(fila.RPCError) as exc_info: + client.enqueue("test-auth", None, b"no-key") + assert exc_info.value.code == grpc.StatusCode.UNAUTHENTICATED + + @pytest.mark.asyncio + async def test_async_api_key_enqueue(self, auth_server: object) -> None: + """Async client with API key can enqueue successfully.""" + from tests.conftest import TestServer + + assert isinstance(auth_server, TestServer) + assert auth_server.api_key is not None + + auth_server.create_queue("test-async-auth") + + async with fila.AsyncClient( + auth_server.addr, api_key=auth_server.api_key + ) as client: + msg_id = await client.enqueue("test-async-auth", None, b"async auth") + assert msg_id != ""