From ac2c4a2b8ca7058e088f9547e5efe85c5ccd0496 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Thu, 15 May 2025 10:51:53 +0900 Subject: [PATCH 1/3] all: add calloptions and handlercontext --- src/connect/call_options.py | 11 +++++++ src/connect/client.py | 60 ++++++++++++++++++++++------------ src/connect/connect.py | 24 +++++++------- src/connect/handler.py | 48 +++++++++++++-------------- src/connect/handler_context.py | 11 +++++++ src/connect/interceptor.py | 5 +-- 6 files changed, 99 insertions(+), 60 deletions(-) create mode 100644 src/connect/call_options.py create mode 100644 src/connect/handler_context.py diff --git a/src/connect/call_options.py b/src/connect/call_options.py new file mode 100644 index 0000000..09f4bd9 --- /dev/null +++ b/src/connect/call_options.py @@ -0,0 +1,11 @@ +import asyncio + +from pydantic import BaseModel, Field + + +class CallOptions(BaseModel): + timeout: float | None = Field(default=None) + """Timeout for the call in seconds.""" + + abort_event: asyncio.Event | None = Field(default=None) + """Event to abort the call.""" diff --git a/src/connect/client.py b/src/connect/client.py index 8c0329b..a14bfd4 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -10,6 +10,7 @@ import httpcore from yarl import URL +from connect.call_options import CallOptions from connect.code import Code from connect.codec import Codec, CodecNameType, ProtoBinaryCodec, ProtoJSONCodec from connect.compression import COMPRESSION_IDENTITY, Compression, GZipCompression, get_compresion_from_name @@ -181,8 +182,10 @@ class Client[T_Request, T_Response]: config: ClientConfig protocol_client: ProtocolClient - _call_unary: Callable[[UnaryRequest[T_Request]], Awaitable[UnaryResponse[T_Response]]] - _call_stream: Callable[[StreamType, StreamRequest[T_Request]], Awaitable[StreamResponse[T_Response]]] + _call_unary: Callable[[UnaryRequest[T_Request], CallOptions | None], Awaitable[UnaryResponse[T_Response]]] + _call_stream: Callable[ + [StreamType, StreamRequest[T_Request], CallOptions | None], Awaitable[StreamResponse[T_Response]] + ] def __init__( self, @@ -227,7 +230,7 @@ def __init__( unary_spec = config.spec(StreamType.Unary) - async def _unary_func(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response]: + async def _unary_func(request: UnaryRequest[T_Request], call_options: CallOptions) -> UnaryResponse[T_Response]: conn = protocol_client.conn(unary_spec, request.headers) def on_request_send(r: httpcore.Request) -> None: @@ -239,25 +242,29 @@ def on_request_send(r: httpcore.Request) -> None: conn.on_request_send(on_request_send) - await conn.send(aiterate([request.message]), request.timeout, abort_event=request.abort_event) + await conn.send(aiterate([request.message]), call_options.timeout, abort_event=call_options.abort_event) - response = await recieve_unary_response(conn=conn, t=output, abort_event=request.abort_event) + response = await recieve_unary_response(conn=conn, t=output, abort_event=call_options.abort_event) return response unary_func = apply_interceptors(_unary_func, options.interceptors) - async def call_unary(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response]: + async def call_unary( + request: UnaryRequest[T_Request], call_options: CallOptions | None + ) -> UnaryResponse[T_Response]: request.spec = unary_spec request.peer = protocol_client.peer protocol_client.write_request_headers(StreamType.Unary, request.headers) + call_options = call_options or CallOptions() + if not isinstance(request.message, input): raise ConnectError( f"expected request of type: {input.__name__}", Code.INTERNAL, ) - response = await unary_func(request) + response = await unary_func(request, call_options) if not isinstance(response.message, output): raise ConnectError( @@ -267,7 +274,9 @@ async def call_unary(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Respon return response - async def _stream_func(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: + async def _stream_func( + request: StreamRequest[T_Request], call_options: CallOptions + ) -> StreamResponse[T_Response]: conn = protocol_client.conn(request.spec, request.headers) def on_request_send(r: httpcore.Request) -> None: @@ -279,27 +288,30 @@ def on_request_send(r: httpcore.Request) -> None: conn.on_request_send(on_request_send) - await conn.send(request.messages, request.timeout, request.abort_event) + await conn.send(request.messages, call_options.timeout, call_options.abort_event) - response = await recieve_stream_response(conn, output, request.spec, request.abort_event) + response = await recieve_stream_response(conn, output, request.spec, call_options.abort_event) return response stream_func = apply_interceptors(_stream_func, options.interceptors) async def call_stream( - stream_type: StreamType, - request: StreamRequest[T_Request], + stream_type: StreamType, request: StreamRequest[T_Request], call_options: CallOptions | None ) -> StreamResponse[T_Response]: request.spec = config.spec(stream_type) request.peer = protocol_client.peer protocol_client.write_request_headers(stream_type, request.headers) - return await stream_func(request) + call_options = call_options or CallOptions() + + return await stream_func(request, call_options) self._call_unary = call_unary self._call_stream = call_stream - async def call_unary(self, request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response]: + async def call_unary( + self, request: UnaryRequest[T_Request], call_options: CallOptions | None + ) -> UnaryResponse[T_Response]: """Asynchronously calls a unary RPC (Remote Procedure Call) with the given request. Args: @@ -309,10 +321,12 @@ async def call_unary(self, request: UnaryRequest[T_Request]) -> UnaryResponse[T_ UnaryResponse[T_Response]: The response object containing the data received from the server. """ - return await self._call_unary(request) + return await self._call_unary(request, call_options) @contextlib.asynccontextmanager - async def call_server_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: + async def call_server_stream( + self, request: StreamRequest[T_Request], call_options: CallOptions | None = None + ) -> AsyncGenerator[StreamResponse[T_Response]]: """Initiate a server-streaming RPC call and returns an asynchronous generator that yields responses from the server. Args: @@ -332,14 +346,16 @@ async def call_server_stream(self, request: StreamRequest[T_Request]) -> AsyncGe request and response types, respectively. """ - response = await self._call_stream(StreamType.ServerStream, request) + response = await self._call_stream(StreamType.ServerStream, request, call_options) try: yield response finally: await response.aclose() @contextlib.asynccontextmanager - async def call_client_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: + async def call_client_stream( + self, request: StreamRequest[T_Request], call_options: CallOptions | None = None + ) -> AsyncGenerator[StreamResponse[T_Response]]: """Initiate a client-streaming RPC call and returns an asynchronous generator for streaming responses from the server. Args: @@ -358,14 +374,16 @@ async def call_client_stream(self, request: StreamRequest[T_Request]) -> AsyncGe ensure proper cleanup of the response stream. """ - response = await self._call_stream(StreamType.ClientStream, request) + response = await self._call_stream(StreamType.ClientStream, request, call_options) try: yield response finally: await response.aclose() @contextlib.asynccontextmanager - async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: + async def call_bidi_stream( + self, request: StreamRequest[T_Request], call_options: CallOptions | None = None + ) -> AsyncGenerator[StreamResponse[T_Response]]: """Initiate a bidirectional streaming call with the server. This method sends a stream request to the server and returns an asynchronous @@ -387,7 +405,7 @@ async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> AsyncGene connection is closed in the `finally` block. """ - response = await self._call_stream(StreamType.BiDiStream, request) + response = await self._call_stream(StreamType.BiDiStream, request, call_options) try: yield response finally: diff --git a/src/connect/connect.py b/src/connect/connect.py index 81b36fc..3232a06 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -148,8 +148,8 @@ class StreamRequest[T](RequestCommon): """ _messages: AsyncIterable[T] - timeout: float | None - abort_event: asyncio.Event | None = None + # timeout: float | None + # abort_event: asyncio.Event | None = None def __init__( self, @@ -158,8 +158,8 @@ def __init__( peer: Peer | None = None, headers: Headers | None = None, method: str | None = None, - timeout: float | None = None, - abort_event: asyncio.Event | None = None, + # timeout: float | None = None, + # abort_event: asyncio.Event | None = None, ) -> None: """Initialize a new Request instance. @@ -178,8 +178,8 @@ def __init__( """ super().__init__(spec, peer, headers, method) self._messages = content if isinstance(content, AsyncIterable) else aiterate([content]) - self.timeout = timeout - self.abort_event = abort_event + # self.timeout = timeout + # self.abort_event = abort_event @property def messages(self) -> AsyncIterable[T]: @@ -200,8 +200,8 @@ class UnaryRequest[T](RequestCommon): """ _message: T - timeout: float | None - abort_event: asyncio.Event | None = None + # timeout: float | None + # abort_event: asyncio.Event | None = None def __init__( self, @@ -210,8 +210,8 @@ def __init__( peer: Peer | None = None, headers: Headers | None = None, method: str | None = None, - timeout: float | None = None, - abort_event: asyncio.Event | None = None, + # timeout: float | None = None, + # abort_event: asyncio.Event | None = None, ) -> None: """Initialize a new Request instance. @@ -230,8 +230,8 @@ def __init__( """ super().__init__(spec, peer, headers, method) self._message = content - self.timeout = timeout - self.abort_event = abort_event + # self.timeout = timeout + # self.abort_event = abort_event @property def message(self) -> T: diff --git a/src/connect/handler.py b/src/connect/handler.py index bb3be8a..45ad7cd 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -23,6 +23,7 @@ receive_unary_request, ) from connect.error import ConnectError +from connect.handler_context import HandlerContext from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel from connect.interceptor import apply_interceptors @@ -45,8 +46,12 @@ from connect.response_writer import ServerResponseWriter from connect.utils import aiterate -type UnaryFunc[T_Request, T_Response] = Callable[[UnaryRequest[T_Request]], Awaitable[UnaryResponse[T_Response]]] -type StreamFunc[T_Request, T_Response] = Callable[[StreamRequest[T_Request]], Awaitable[StreamResponse[T_Response]]] +type UnaryFunc[T_Request, T_Response] = Callable[ + [UnaryRequest[T_Request], HandlerContext], Awaitable[UnaryResponse[T_Response]] +] +type StreamFunc[T_Request, T_Response] = Callable[ + [StreamRequest[T_Request], HandlerContext], Awaitable[StreamResponse[T_Response]] +] class HandlerConfig: @@ -386,8 +391,8 @@ def __init__( config = HandlerConfig(procedure=procedure, stream_type=StreamType.Unary, options=options) protocol_handlers = create_protocol_handlers(config) - async def _call(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response]: - response = await unary(request) + async def _call(request: UnaryRequest[T_Request], context: HandlerContext) -> UnaryResponse[T_Response]: + response = await unary(request, context) return response @@ -421,10 +426,8 @@ async def implementation(self, conn: StreamingHandlerConn, timeout: float | None """ request = await receive_unary_request(conn, self.input) - if timeout: - request.timeout = timeout - - response = await self.call(request) + context = HandlerContext(timeout=timeout) + response = await self.call(request, context) conn.response_headers.update(exclude_protocol_headers(response.headers)) conn.response_trailers.update(exclude_protocol_headers(response.trailers)) @@ -476,8 +479,8 @@ def __init__( config = HandlerConfig(procedure=procedure, stream_type=StreamType.ServerStream, options=options) protocol_handlers = create_protocol_handlers(config) - async def _call(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: - response = await stream(request) + async def _call(request: StreamRequest[T_Request], context: HandlerContext) -> StreamResponse[T_Response]: + response = await stream(request, context) return response call = apply_interceptors(_call, options.interceptors) @@ -512,10 +515,8 @@ async def implementation(self, conn: StreamingHandlerConn, timeout: float | None """ request = await receive_stream_request(conn, self.input) - if timeout: - request.timeout = timeout - - response = await self.call(request) + context = HandlerContext(timeout=timeout) + response = await self.call(request, context) conn.response_headers.update(response.headers) conn.response_trailers.update(response.trailers) @@ -569,8 +570,8 @@ def __init__( config = HandlerConfig(procedure=procedure, stream_type=StreamType.ClientStream, options=options) protocol_handlers = create_protocol_handlers(config) - async def _call(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: - response = await stream(request) + async def _call(request: StreamRequest[T_Request], context: HandlerContext) -> StreamResponse[T_Response]: + response = await stream(request, context) return response call = apply_interceptors(_call, options.interceptors) @@ -605,10 +606,9 @@ async def implementation(self, conn: StreamingHandlerConn, timeout: float | None """ request = await receive_stream_request(conn, self.input) - if timeout: - request.timeout = timeout + context = HandlerContext(timeout=timeout) - response = await self.call(request) + response = await self.call(request, context) conn.response_headers.update(response.headers) conn.response_trailers.update(response.trailers) @@ -663,8 +663,8 @@ def __init__( config = HandlerConfig(procedure=procedure, stream_type=StreamType.BiDiStream, options=options) protocol_handlers = create_protocol_handlers(config) - async def _call(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: - response = await stream(request) + async def _call(request: StreamRequest[T_Request], context: HandlerContext) -> StreamResponse[T_Response]: + response = await stream(request, context) return response call = apply_interceptors(_call, options.interceptors) @@ -699,10 +699,8 @@ async def implementation(self, conn: StreamingHandlerConn, timeout: float | None """ request = await receive_stream_request(conn, self.input) - if timeout: - request.timeout = timeout - - response = await self.call(request) + context = HandlerContext(timeout=timeout) + response = await self.call(request, context) conn.response_headers.update(response.headers) conn.response_trailers.update(response.trailers) diff --git a/src/connect/handler_context.py b/src/connect/handler_context.py new file mode 100644 index 0000000..f01e41c --- /dev/null +++ b/src/connect/handler_context.py @@ -0,0 +1,11 @@ +class HandlerContext: + timeout: float | None + + def __init__(self, timeout: float | None) -> None: + self.timeout = timeout + + def timeout_remaining(self) -> float: + if self.timeout is None: + return 0 + + return self.timeout diff --git a/src/connect/interceptor.py b/src/connect/interceptor.py index 2bc8ab5..1b09743 100644 --- a/src/connect/interceptor.py +++ b/src/connect/interceptor.py @@ -4,10 +4,11 @@ from collections.abc import Awaitable, Callable from typing import Any, TypeGuard, overload +from connect.call_options import CallOptions from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse -UnaryFunc = Callable[[UnaryRequest[Any]], Awaitable[UnaryResponse[Any]]] -StreamFunc = Callable[[StreamRequest[Any]], Awaitable[StreamResponse[Any]]] +UnaryFunc = Callable[[UnaryRequest[Any], CallOptions], Awaitable[UnaryResponse[Any]]] +StreamFunc = Callable[[StreamRequest[Any], CallOptions], Awaitable[StreamResponse[Any]]] class Interceptor: From 5e234c9fcbe708c77ef9e80824cec9e14d9c3a2b Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Fri, 16 May 2025 19:23:10 +0900 Subject: [PATCH 2/3] all: add call options and handler context --- conformance/client_runner.py | 21 +++- .../conformancev1connect/service_connect.py | 13 ++- conformance/server.py | 32 +++-- examples/server.py | 9 +- src/connect/call_options.py | 4 +- src/connect/client.py | 4 +- .../{interceptor.py => client_interceptor.py} | 14 ++- src/connect/handler.py | 5 +- src/connect/handler_context.py | 21 +++- src/connect/handler_interceptor.py | 109 ++++++++++++++++++ src/connect/options.py | 7 +- tests/test_streaming_connect_client.py | 27 ++--- tests/test_streaming_connect_server.py | 59 ++++++---- tests/test_unary_connect_client.py | 15 +-- tests/test_unary_connect_server.py | 29 +++-- .../ping/v1/v1connect/ping_connect.py | 7 +- 16 files changed, 277 insertions(+), 99 deletions(-) rename src/connect/{interceptor.py => client_interceptor.py} (88%) create mode 100644 src/connect/handler_interceptor.py diff --git a/conformance/client_runner.py b/conformance/client_runner.py index 0e37df4..3b43719 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -10,6 +10,7 @@ from collections.abc import AsyncGenerator from typing import Any +from connect.call_options import CallOptions from connect.connect import StreamRequest, UnaryRequest from connect.connection_pool import AsyncConnectionPool from connect.error import ConnectError @@ -253,6 +254,8 @@ async def delayed_abort() -> None: UnaryRequest( content=req, headers=headers, + ), + CallOptions( timeout=msg.timeout_ms / 1000, abort_event=abort_event, ), @@ -290,8 +293,10 @@ async def delayed_abort() -> None: asyncio.create_task(delayed_abort()) async with getattr(client, msg.method)( - StreamRequest( - content=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event + StreamRequest(content=_reqs(), headers=headers), + CallOptions( + timeout=msg.timeout_ms / 1000, + abort_event=abort_event, ), ) as resp: async for message in resp.messages: @@ -314,8 +319,10 @@ async def delayed_abort() -> None: headers = to_connect_headers(msg.request_headers) async with getattr(client, msg.method)( - StreamRequest( - content=reqs, headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event + StreamRequest(content=reqs, headers=headers), + CallOptions( + timeout=msg.timeout_ms / 1000, + abort_event=abort_event, ), ) as resp: if msg.cancel.HasField("after_close_send_ms"): @@ -356,8 +363,10 @@ async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: headers = to_connect_headers(msg.request_headers) async with getattr(client, msg.method)( - StreamRequest( - content=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event + StreamRequest(content=_reqs(), headers=headers), + CallOptions( + timeout=msg.timeout_ms / 1000, + abort_event=abort_event, ), ) as resp: if msg.cancel.HasField("before_close_send"): diff --git a/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py b/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py index ee2a8ec..c54e429 100644 --- a/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py +++ b/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py @@ -9,6 +9,7 @@ from connect.client import Client import connect.connect from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler, BidiStreamHandler +from connect.handler_context import HandlerContext from connect.options import ClientOptions, ConnectOptions from connect.connection_pool import AsyncConnectionPool from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor @@ -66,22 +67,22 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti class ConformanceServiceHandler: """Handler for the conformanceService service.""" - async def Unary(self, request: connect.connect.UnaryRequest[UnaryRequest]) -> connect.connect.UnaryResponse[UnaryResponse]: + async def Unary(self, request: connect.connect.UnaryRequest[UnaryRequest], context: HandlerContext) -> connect.connect.UnaryResponse[UnaryResponse]: raise NotImplementedError() - async def ServerStream(self, request: connect.connect.StreamRequest[ServerStreamRequest]) -> connect.connect.StreamResponse[ServerStreamResponse]: + async def ServerStream(self, request: connect.connect.StreamRequest[ServerStreamRequest], context: HandlerContext) -> connect.connect.StreamResponse[ServerStreamResponse]: raise NotImplementedError() - async def ClientStream(self, request: connect.connect.StreamRequest[ClientStreamRequest]) -> connect.connect.StreamResponse[ClientStreamResponse]: + async def ClientStream(self, request: connect.connect.StreamRequest[ClientStreamRequest], context: HandlerContext) -> connect.connect.StreamResponse[ClientStreamResponse]: raise NotImplementedError() - async def BidiStream(self, request: connect.connect.StreamRequest[BidiStreamRequest]) -> connect.connect.StreamResponse[BidiStreamResponse]: + async def BidiStream(self, request: connect.connect.StreamRequest[BidiStreamRequest], context: HandlerContext) -> connect.connect.StreamResponse[BidiStreamResponse]: raise NotImplementedError() - async def Unimplemented(self, request: connect.connect.UnaryRequest[UnimplementedRequest]) -> connect.connect.UnaryResponse[UnimplementedResponse]: + async def Unimplemented(self, request: connect.connect.UnaryRequest[UnimplementedRequest], context: HandlerContext) -> connect.connect.UnaryResponse[UnimplementedResponse]: raise NotImplementedError() - async def IdempotentUnary(self, request: connect.connect.UnaryRequest[IdempotentUnaryRequest]) -> connect.connect.UnaryResponse[IdempotentUnaryResponse]: + async def IdempotentUnary(self, request: connect.connect.UnaryRequest[IdempotentUnaryRequest], context: HandlerContext) -> connect.connect.UnaryResponse[IdempotentUnaryResponse]: raise NotImplementedError() diff --git a/conformance/server.py b/conformance/server.py index b0f2c4c..26d1f2d 100644 --- a/conformance/server.py +++ b/conformance/server.py @@ -8,6 +8,7 @@ from connect.code import Code from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse from connect.error import ConnectError, ErrorDetail +from connect.handler_context import HandlerContext from connect.headers import Headers from connect.middleware import ConnectMiddleware from starlette.applications import Starlette @@ -95,7 +96,9 @@ def code_from_pb_code(code: config_pb2.Code) -> Code: class ConformanceService(ConformanceServiceHandler): """ConformanceService is a service handler that implements various gRPC methods for testing conformance.""" - async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryResponse[service_pb2.UnaryResponse]: + async def Unary( + self, request: UnaryRequest[service_pb2.UnaryRequest], context: HandlerContext + ) -> UnaryResponse[service_pb2.UnaryResponse]: """Handle a unary gRPC request and generates a response based on the provided request definition. Args: @@ -128,10 +131,12 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR request_any = any_pb2.Any() request_any.Pack(request.message) + timeout_sec = context.timeout_remaining() + request_info = service_pb2.ConformancePayload.RequestInfo( request_headers=pb_headers_from_headers(request.headers), requests=[request_any], - timeout_ms=int(request.timeout) if request.timeout else None, + timeout_ms=int(timeout_sec * 1000) if timeout_sec else None, connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo( query_params=pb_query_params_from_peer_query(request.peer.query), ), @@ -175,7 +180,7 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR return UnaryResponse(content=service_pb2.UnaryResponse(payload=payload), headers=headers, trailers=trailers) async def IdempotentUnary( - self, request: UnaryRequest[service_pb2.IdempotentUnaryRequest] + self, request: UnaryRequest[service_pb2.IdempotentUnaryRequest], context: HandlerContext ) -> UnaryResponse[service_pb2.IdempotentUnaryResponse]: """Handle the IdempotentUnary RPC call. @@ -202,10 +207,11 @@ async def IdempotentUnary( request_any = any_pb2.Any() request_any.Pack(request.message) + timeout_sec = context.timeout_remaining() request_info = service_pb2.ConformancePayload.RequestInfo( request_headers=pb_headers_from_headers(request.headers), requests=[request_any], - timeout_ms=int(request.timeout) if request.timeout else None, + timeout_ms=int(timeout_sec * 1000) if timeout_sec else None, connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo( query_params=pb_query_params_from_peer_query(request.peer.query), ), @@ -251,7 +257,7 @@ async def IdempotentUnary( ) async def ClientStream( - self, request: StreamRequest[service_pb2.ClientStreamRequest] + self, request: StreamRequest[service_pb2.ClientStreamRequest], context: HandlerContext ) -> StreamResponse[service_pb2.ClientStreamResponse]: """Handle a bidirectional streaming RPC where the client sends a stream of `ClientStreamRequest` messages and receives a single `ClientStreamResponse` message. @@ -287,10 +293,11 @@ async def ClientStream( message_any.Pack(message) messages.append(message_any) + timeout_sec = context.timeout_remaining() request_info = service_pb2.ConformancePayload.RequestInfo( request_headers=pb_headers_from_headers(request.headers), requests=messages, - timeout_ms=int(request.timeout) if request.timeout else None, + timeout_ms=int(timeout_sec * 1000) if timeout_sec else None, connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo( query_params=pb_query_params_from_peer_query(request.peer.query), ), @@ -339,7 +346,7 @@ async def ClientStream( ) async def ServerStream( - self, request: StreamRequest[service_pb2.ServerStreamRequest] + self, request: StreamRequest[service_pb2.ServerStreamRequest], context: HandlerContext ) -> StreamResponse[service_pb2.ServerStreamResponse]: """Handle a server-side streaming RPC call. @@ -379,10 +386,11 @@ async def ServerStream( headers = headers_from_pb_headers(response_definition.response_headers) trailers = headers_from_pb_headers(response_definition.response_trailers) + timeout_sec = context.timeout_remaining() request_info = service_pb2.ConformancePayload.RequestInfo( request_headers=pb_headers_from_headers(request.headers), requests=messages, - timeout_ms=int(request.timeout) if request.timeout else None, + timeout_ms=int(timeout_sec * 1000) if timeout_sec else None, connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo( query_params=pb_query_params_from_peer_query(request.peer.query), ), @@ -442,7 +450,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.ServerStreamResponse]: ) async def BidiStream( - self, request: StreamRequest[service_pb2.BidiStreamRequest] + self, request: StreamRequest[service_pb2.BidiStreamRequest], context: HandlerContext ) -> StreamResponse[service_pb2.BidiStreamResponse]: """Handle a bidirectional streaming RPC. @@ -488,6 +496,8 @@ async def BidiStream( headers = headers_from_pb_headers(response_definition.response_headers) trailers = headers_from_pb_headers(response_definition.response_trailers) + timeout_sec = context.timeout_remaining() + async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]: nonlocal response_index @@ -496,7 +506,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]: request_info = service_pb2.ConformancePayload.RequestInfo( request_headers=pb_headers_from_headers(request.headers), requests=messages, - timeout_ms=int(request.timeout) if request.timeout else None, + timeout_ms=int(timeout_sec * 1000) if timeout_sec else None, ) else: request_info = None @@ -525,7 +535,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]: request_info = service_pb2.ConformancePayload.RequestInfo( request_headers=pb_headers_from_headers(request.headers), requests=messages, - timeout_ms=int(request.timeout) if request.timeout else None, + timeout_ms=int(timeout_sec * 1000) if timeout_sec else None, ) detail = any_pb2.Any() diff --git a/examples/server.py b/examples/server.py index 2d0c08e..83f200b 100644 --- a/examples/server.py +++ b/examples/server.py @@ -5,7 +5,8 @@ import hypercorn.asyncio from connect.connect import UnaryRequest, UnaryResponse -from connect.interceptor import Interceptor, UnaryFunc +from connect.handler_context import HandlerContext +from connect.handler_interceptor import HandlerInterceptor, UnaryFunc from connect.middleware import ConnectMiddleware from starlette.applications import Starlette from starlette.middleware import Middleware @@ -23,13 +24,13 @@ async def Say(self, request: UnaryRequest[SayRequest]) -> UnaryResponse[SayRespo return UnaryResponse(SayResponse(sentence=data.sentence)) -class IPRestrictionInterceptor(Interceptor): +class IPRestrictionInterceptor(HandlerInterceptor): """IP restriction interceptor.""" def wrap_unary(self, next: UnaryFunc) -> UnaryFunc: """Wrap a unary function with the interceptor.""" - async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]: + async def _wrapped(request: UnaryRequest[Any], context: HandlerContext) -> UnaryResponse[Any]: ip_allow_list = os.environ.get("IP_ALLOW_LIST", "").split(",") if not ip_allow_list: raise Exception("White list not found") @@ -45,7 +46,7 @@ async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]: if ip not in ip_allow_list: raise Exception("IP not allowed") - return await next(request) + return await next(request, context) return _wrapped diff --git a/src/connect/call_options.py b/src/connect/call_options.py index 09f4bd9..3a6462c 100644 --- a/src/connect/call_options.py +++ b/src/connect/call_options.py @@ -1,9 +1,11 @@ import asyncio -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class CallOptions(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + timeout: float | None = Field(default=None) """Timeout for the call in seconds.""" diff --git a/src/connect/client.py b/src/connect/client.py index a14bfd4..211b0e1 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -11,6 +11,7 @@ from yarl import URL from connect.call_options import CallOptions +from connect.client_interceptor import apply_interceptors from connect.code import Code from connect.codec import Codec, CodecNameType, ProtoBinaryCodec, ProtoJSONCodec from connect.compression import COMPRESSION_IDENTITY, Compression, GZipCompression, get_compresion_from_name @@ -27,7 +28,6 @@ from connect.connection_pool import AsyncConnectionPool from connect.error import ConnectError from connect.idempotency_level import IdempotencyLevel -from connect.interceptor import apply_interceptors from connect.options import ClientOptions from connect.protocol import Protocol, ProtocolClient, ProtocolClientParams from connect.protocol_connect.connect_protocol import ProtocolConnect @@ -310,7 +310,7 @@ async def call_stream( self._call_stream = call_stream async def call_unary( - self, request: UnaryRequest[T_Request], call_options: CallOptions | None + self, request: UnaryRequest[T_Request], call_options: CallOptions | None = None ) -> UnaryResponse[T_Response]: """Asynchronously calls a unary RPC (Remote Procedure Call) with the given request. diff --git a/src/connect/interceptor.py b/src/connect/client_interceptor.py similarity index 88% rename from src/connect/interceptor.py rename to src/connect/client_interceptor.py index 1b09743..70f1536 100644 --- a/src/connect/interceptor.py +++ b/src/connect/client_interceptor.py @@ -11,7 +11,7 @@ StreamFunc = Callable[[StreamRequest[Any], CallOptions], Awaitable[StreamResponse[Any]]] -class Interceptor: +class ClientInterceptor: """Abstract base class for interceptors that can wrap unary functions.""" wrap_unary: Callable[[UnaryFunc], UnaryFunc] | None = None @@ -35,7 +35,7 @@ def is_unary_func(next: UnaryFunc | StreamFunc) -> TypeGuard[UnaryFunc]: parameters = list(signature.parameters.values()) return bool( callable(next) - and len(parameters) == 1 + and len(parameters) == 2 and getattr(parameters[0].annotation, "__origin__", None) is UnaryRequest ) @@ -57,20 +57,22 @@ def is_stream_func(next: UnaryFunc | StreamFunc) -> TypeGuard[StreamFunc]: parameters = list(signature.parameters.values()) return bool( callable(next) - and len(parameters) == 1 + and len(parameters) == 2 and getattr(parameters[0].annotation, "__origin__", None) is StreamRequest ) @overload -def apply_interceptors(next: UnaryFunc, interceptors: list[Interceptor] | None) -> UnaryFunc: ... +def apply_interceptors(next: UnaryFunc, interceptors: list[ClientInterceptor] | None) -> UnaryFunc: ... @overload -def apply_interceptors(next: StreamFunc, interceptors: list[Interceptor] | None) -> StreamFunc: ... +def apply_interceptors(next: StreamFunc, interceptors: list[ClientInterceptor] | None) -> StreamFunc: ... -def apply_interceptors(next: UnaryFunc | StreamFunc, interceptors: list[Interceptor] | None) -> UnaryFunc | StreamFunc: +def apply_interceptors( + next: UnaryFunc | StreamFunc, interceptors: list[ClientInterceptor] | None +) -> UnaryFunc | StreamFunc: """Apply a list of interceptors to a given function. Args: diff --git a/src/connect/handler.py b/src/connect/handler.py index 45ad7cd..d14bee6 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -24,9 +24,9 @@ ) from connect.error import ConnectError from connect.handler_context import HandlerContext +from connect.handler_interceptor import apply_interceptors from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel -from connect.interceptor import apply_interceptors from connect.options import ConnectOptions from connect.protocol import ( HEADER_CONTENT_LENGTH, @@ -333,9 +333,8 @@ async def _handle( try: timeout = conn.parse_timeout() if timeout: - timeout_ms = int(timeout * 1000) with anyio.fail_after(delay=timeout): - await self.implementation(conn, timeout_ms) + await self.implementation(conn, timeout) else: await self.implementation(conn, None) diff --git a/src/connect/handler_context.py b/src/connect/handler_context.py index f01e41c..44c26b9 100644 --- a/src/connect/handler_context.py +++ b/src/connect/handler_context.py @@ -1,11 +1,20 @@ +import time + + class HandlerContext: - timeout: float | None + _deadline: float | None def __init__(self, timeout: float | None) -> None: - self.timeout = timeout + self._deadline = time.time() + timeout if timeout else None + + def timeout_remaining(self) -> float | None: + """Return the remaining time in seconds until the deadline, or None if no deadline is set. + + Returns: + float | None: The number of seconds remaining until the deadline, or None if no deadline is set. - def timeout_remaining(self) -> float: - if self.timeout is None: - return 0 + """ + if self._deadline is None: + return None - return self.timeout + return self._deadline - time.time() diff --git a/src/connect/handler_interceptor.py b/src/connect/handler_interceptor.py new file mode 100644 index 0000000..45f2d18 --- /dev/null +++ b/src/connect/handler_interceptor.py @@ -0,0 +1,109 @@ +"""Defines interceptors and request/response classes for unary and streaming RPC calls.""" + +import inspect +from collections.abc import Awaitable, Callable +from typing import Any, TypeGuard, overload + +from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse +from connect.handler_context import HandlerContext + +UnaryFunc = Callable[[UnaryRequest[Any], HandlerContext], Awaitable[UnaryResponse[Any]]] +StreamFunc = Callable[[StreamRequest[Any], HandlerContext], Awaitable[StreamResponse[Any]]] + + +class HandlerInterceptor: + """Abstract base class for interceptors that can wrap unary functions.""" + + wrap_unary: Callable[[UnaryFunc], UnaryFunc] | None = None + wrap_stream: Callable[[StreamFunc], StreamFunc] | None = None + + +def is_unary_func(next: UnaryFunc | StreamFunc) -> TypeGuard[UnaryFunc]: + """Determine if the given function is a unary function. + + A unary function is defined as a callable that takes a single parameter + whose type annotation has an origin of `UnaryRequest`. + + Args: + next (UnaryFunc | StreamFunc): The function to be checked. + + Returns: + TypeGuard[UnaryFunc]: True if the function is a unary function, False otherwise. + + """ + signature = inspect.signature(next) + parameters = list(signature.parameters.values()) + return bool( + callable(next) + and len(parameters) == 2 + and getattr(parameters[0].annotation, "__origin__", None) is UnaryRequest + ) + + +def is_stream_func(next: UnaryFunc | StreamFunc) -> TypeGuard[StreamFunc]: + """Determine if the given function is a StreamFunc. + + This function checks if the provided function `next` is callable, has exactly one parameter, + and if the annotation of that parameter has an origin of `StreamRequest`. + + Args: + next (UnaryFunc | StreamFunc): The function to be checked. + + Returns: + TypeGuard[StreamFunc]: True if `next` is a StreamFunc, False otherwise. + + """ + signature = inspect.signature(next) + parameters = list(signature.parameters.values()) + return bool( + callable(next) + and len(parameters) == 2 + and getattr(parameters[0].annotation, "__origin__", None) is StreamRequest + ) + + +@overload +def apply_interceptors(next: UnaryFunc, interceptors: list[HandlerInterceptor] | None) -> UnaryFunc: ... + + +@overload +def apply_interceptors(next: StreamFunc, interceptors: list[HandlerInterceptor] | None) -> StreamFunc: ... + + +def apply_interceptors( + next: UnaryFunc | StreamFunc, interceptors: list[HandlerInterceptor] | None +) -> UnaryFunc | StreamFunc: + """Apply a list of interceptors to a given function. + + Args: + next (UnaryFunc | StreamFunc): The function to which interceptors will be applied. + It can be either a unary function or a stream function. + interceptors (list[Interceptor] | None): A list of interceptors to apply. If None, the original function is returned. + + Returns: + UnaryFunc | StreamFunc: The function wrapped with the provided interceptors. + + Raises: + ValueError: If an interceptor does not implement the required wrap method for the function type, + or if the provided function type is invalid. + + """ + if interceptors is None: + return next + + _next = next + if is_unary_func(_next): + for interceptor in interceptors: + if interceptor.wrap_unary is None: + break + _next = interceptor.wrap_unary(_next) + return _next + + elif is_stream_func(_next): + for interceptor in interceptors: + if interceptor.wrap_stream is None: + break + _next = interceptor.wrap_stream(_next) + return _next + else: + raise ValueError(f"Invalid function type: {next}") diff --git a/src/connect/options.py b/src/connect/options.py index e614f21..9c8770c 100644 --- a/src/connect/options.py +++ b/src/connect/options.py @@ -4,8 +4,9 @@ from pydantic import BaseModel, ConfigDict, Field +from connect.client_interceptor import ClientInterceptor +from connect.handler_interceptor import HandlerInterceptor from connect.idempotency_level import IdempotencyLevel -from connect.interceptor import Interceptor class ConnectOptions(BaseModel): @@ -13,7 +14,7 @@ class ConnectOptions(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - interceptors: list[Interceptor] = Field(default=[]) + interceptors: list[HandlerInterceptor] = Field(default=[]) """A list of interceptors to apply to the handler.""" descriptor: Any = Field(default="") @@ -60,7 +61,7 @@ class ClientOptions(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - interceptors: list[Interceptor] = Field(default=[]) + interceptors: list[ClientInterceptor] = Field(default=[]) """A list of interceptors to apply to the handler.""" descriptor: Any = Field(default="") diff --git a/tests/test_streaming_connect_client.py b/tests/test_streaming_connect_client.py index 999491a..25c5e39 100644 --- a/tests/test_streaming_connect_client.py +++ b/tests/test_streaming_connect_client.py @@ -8,13 +8,14 @@ import pytest +from connect.call_options import CallOptions from connect.client import Client +from connect.client_interceptor import ClientInterceptor, StreamFunc from connect.code import Code from connect.connect import StreamRequest, StreamResponse, StreamType from connect.connection_pool import AsyncConnectionPool from connect.envelope import Envelope, EnvelopeFlags from connect.error import ConnectError -from connect.interceptor import Interceptor, StreamFunc from connect.options import ClientOptions from tests.conftest import ASGIRequest, Receive, Scope, Send, ServerConfig from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse @@ -438,9 +439,9 @@ async def test_server_streaming_interceptor(hypercorn_server: ServerConfig) -> N ephemeral_files: list[io.BufferedRandom] = [] - class FileInterceptor1(Interceptor): + class FileInterceptor1(ClientInterceptor): def wrap_stream(self, next: StreamFunc) -> StreamFunc: - async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: + async def _wrapped(request: StreamRequest[Any], call_options: CallOptions) -> StreamResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -450,13 +451,13 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 1") - return await next(request) + return await next(request, call_options) return _wrapped - class FileInterceptor2(Interceptor): + class FileInterceptor2(ClientInterceptor): def wrap_stream(self, next: StreamFunc) -> StreamFunc: - async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: + async def _wrapped(request: StreamRequest[Any], call_options: CallOptions) -> StreamResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -466,7 +467,7 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 2") - return await next(request) + return await next(request, call_options) return _wrapped @@ -597,9 +598,9 @@ async def test_client_streaming_interceptor(hypercorn_server: ServerConfig) -> N ephemeral_files: list[io.BufferedRandom] = [] - class FileInterceptor1(Interceptor): + class FileInterceptor1(ClientInterceptor): def wrap_stream(self, next: StreamFunc) -> StreamFunc: - async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: + async def _wrapped(request: StreamRequest[Any], call_options: CallOptions) -> StreamResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -609,13 +610,13 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 1") - return await next(request) + return await next(request, call_options) return _wrapped - class FileInterceptor2(Interceptor): + class FileInterceptor2(ClientInterceptor): def wrap_stream(self, next: StreamFunc) -> StreamFunc: - async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: + async def _wrapped(request: StreamRequest[Any], call_options: CallOptions) -> StreamResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -625,7 +626,7 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 2") - return await next(request) + return await next(request, call_options) return _wrapped diff --git a/tests/test_streaming_connect_server.py b/tests/test_streaming_connect_server.py index 960cd30..0d42d76 100644 --- a/tests/test_streaming_connect_server.py +++ b/tests/test_streaming_connect_server.py @@ -9,8 +9,9 @@ from connect.connect import StreamRequest, StreamResponse, StreamType from connect.envelope import Envelope, EnvelopeFlags from connect.error import ConnectError +from connect.handler_context import HandlerContext +from connect.handler_interceptor import HandlerInterceptor, StreamFunc from connect.headers import Headers -from connect.interceptor import Interceptor, StreamFunc from connect.options import ConnectOptions from tests.conftest import AsyncClient from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse @@ -24,7 +25,9 @@ @pytest.mark.asyncio() async def test_server_streaming() -> None: class PingService(PingServiceHandler): - async def PingServerStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: + async def PingServerStream( + self, request: StreamRequest[PingRequest], context: HandlerContext + ) -> StreamResponse[PingResponse]: messages = "" async for data in request.messages: messages += " " + data.name @@ -76,7 +79,9 @@ def to_bytes() -> bytes: @pytest.mark.asyncio() async def test_server_streaming_end_stream_error() -> None: class PingService(PingServiceHandler): - async def PingServerStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: + async def PingServerStream( + self, request: StreamRequest[PingRequest], context: HandlerContext + ) -> StreamResponse[PingResponse]: messages = "" async for data in request.messages: messages += " " + data.name @@ -131,7 +136,9 @@ def to_bytes() -> bytes: @pytest.mark.asyncio() async def test_server_streaming_response_envelope_message_compression() -> None: class PingService(PingServiceHandler): - async def PingServerStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: + async def PingServerStream( + self, request: StreamRequest[PingRequest], context: HandlerContext + ) -> StreamResponse[PingResponse]: messages = "" async for data in request.messages: messages += " " + data.name @@ -190,7 +197,9 @@ def to_bytes() -> bytes: @pytest.mark.asyncio() async def test_server_streaming_request_envelope_message_compression() -> None: class PingService(PingServiceHandler): - async def PingServerStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: + async def PingServerStream( + self, request: StreamRequest[PingRequest], context: HandlerContext + ) -> StreamResponse[PingResponse]: messages = "" async for data in request.messages: messages += " " + data.name @@ -254,7 +263,9 @@ def to_bytes() -> bytes: @pytest.mark.asyncio() async def test_server_streaming_invalid_request_envelope_message_compression() -> None: class PingService(PingServiceHandler): - async def PingServerStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: + async def PingServerStream( + self, request: StreamRequest[PingRequest], context: HandlerContext + ) -> StreamResponse[PingResponse]: messages = "" async for data in request.messages: messages += " " + data.name @@ -322,7 +333,9 @@ async def test_server_streaming_interceptor() -> None: import tempfile class PingService(PingServiceHandler): - async def PingServerStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: + async def PingServerStream( + self, request: StreamRequest[PingRequest], context: HandlerContext + ) -> StreamResponse[PingResponse]: async def iterator() -> AsyncIterator[PingResponse]: for i in range(3): yield PingResponse(name=f"Hello {i}!") @@ -335,9 +348,9 @@ def to_bytes() -> bytes: ephemeral_files: list[io.BufferedRandom] = [] - class FileInterceptor1(Interceptor): + class FileInterceptor1(HandlerInterceptor): def wrap_stream(self, next: StreamFunc) -> StreamFunc: - async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: + async def _wrapped(request: StreamRequest[Any], context: HandlerContext) -> StreamResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -347,13 +360,13 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 1") - return await next(request) + return await next(request, context) return _wrapped - class FileInterceptor2(Interceptor): + class FileInterceptor2(HandlerInterceptor): def wrap_stream(self, next: StreamFunc) -> StreamFunc: - async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: + async def _wrapped(request: StreamRequest[Any], context: HandlerContext) -> StreamResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -363,7 +376,7 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 2") - return await next(request) + return await next(request, context) return _wrapped @@ -395,7 +408,9 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: @pytest.mark.asyncio() async def test_client_streaming() -> None: class PingService(PingServiceHandler): - async def PingClientStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: + async def PingClientStream( + self, request: StreamRequest[PingRequest], context: HandlerContext + ) -> StreamResponse[PingResponse]: messages = "" async for data in request.messages: messages += data.name @@ -450,7 +465,9 @@ async def test_client_streaming_interceptor() -> None: import tempfile class PingService(PingServiceHandler): - async def PingClientStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: + async def PingClientStream( + self, request: StreamRequest[PingRequest], context: HandlerContext + ) -> StreamResponse[PingResponse]: messages = "" async for data in request.messages: messages += data.name @@ -467,9 +484,9 @@ async def iter_bytes() -> AsyncIterator[bytes]: ephemeral_files: list[io.BufferedRandom] = [] - class FileInterceptor1(Interceptor): + class FileInterceptor1(HandlerInterceptor): def wrap_stream(self, next: StreamFunc) -> StreamFunc: - async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: + async def _wrapped(request: StreamRequest[Any], context: HandlerContext) -> StreamResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -479,13 +496,13 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 1") - return await next(request) + return await next(request, context) return _wrapped - class FileInterceptor2(Interceptor): + class FileInterceptor2(HandlerInterceptor): def wrap_stream(self, next: StreamFunc) -> StreamFunc: - async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: + async def _wrapped(request: StreamRequest[Any], context: HandlerContext) -> StreamResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -495,7 +512,7 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 2") - return await next(request) + return await next(request, context) return _wrapped diff --git a/tests/test_unary_connect_client.py b/tests/test_unary_connect_client.py index 04c4ce8..9543fb4 100644 --- a/tests/test_unary_connect_client.py +++ b/tests/test_unary_connect_client.py @@ -7,13 +7,14 @@ import pytest +from connect.call_options import CallOptions from connect.client import Client +from connect.client_interceptor import ClientInterceptor, UnaryFunc from connect.code import Code from connect.connect import StreamType, UnaryRequest, UnaryResponse from connect.connection_pool import AsyncConnectionPool from connect.error import ConnectError from connect.idempotency_level import IdempotencyLevel -from connect.interceptor import Interceptor, UnaryFunc from connect.options import ClientOptions from tests.conftest import ASGIRequest, Receive, Scope, Send, ServerConfig from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse @@ -357,11 +358,11 @@ async def test_post_interceptor(hypercorn_server: ServerConfig) -> None: ephemeral_files: list[io.BufferedRandom] = [] - class FileInterceptor1(Interceptor): + class FileInterceptor1(ClientInterceptor): def wrap_unary(self, next: UnaryFunc) -> UnaryFunc: """Wrap a unary function with the interceptor.""" - async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]: + async def _wrapped(request: UnaryRequest[Any], call_options: CallOptions) -> UnaryResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -371,15 +372,15 @@ async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 1") - return await next(request) + return await next(request, call_options) return _wrapped - class FileInterceptor2(Interceptor): + class FileInterceptor2(ClientInterceptor): def wrap_unary(self, next: UnaryFunc) -> UnaryFunc: """Wrap a unary function with the interceptor.""" - async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]: + async def _wrapped(request: UnaryRequest[Any], call_options: CallOptions) -> UnaryResponse[Any]: nonlocal ephemeral_files fp = tempfile.TemporaryFile() # noqa: SIM115 @@ -389,7 +390,7 @@ async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]: ephemeral_files.append(fp) fp.write(b"interceptor: 2") - return await next(request) + return await next(request, call_options) return _wrapped diff --git a/tests/test_unary_connect_server.py b/tests/test_unary_connect_server.py index cfb6404..8ba7ec9 100644 --- a/tests/test_unary_connect_server.py +++ b/tests/test_unary_connect_server.py @@ -8,6 +8,7 @@ import pytest from connect.connect import UnaryRequest, UnaryResponse +from connect.handler_context import HandlerContext from connect.idempotency_level import IdempotencyLevel from connect.options import ConnectOptions from tests.conftest import AsyncClient @@ -20,7 +21,9 @@ async def test_post_application_proto() -> None: class PingService(PingServiceHandler): """Ping service implementation.""" - async def Ping(self, request: UnaryRequest[PingRequest]) -> UnaryResponse[PingResponse]: + async def Ping( + self, request: UnaryRequest[PingRequest], context: HandlerContext + ) -> UnaryResponse[PingResponse]: """Return a ping response.""" data = request.message @@ -45,7 +48,9 @@ async def test_post_application_json() -> None: class PingService(PingServiceHandler): """Ping service implementation.""" - async def Ping(self, request: UnaryRequest[PingRequest]) -> UnaryResponse[PingResponse]: + async def Ping( + self, request: UnaryRequest[PingRequest], context: HandlerContext + ) -> UnaryResponse[PingResponse]: """Return a ping response.""" data = request.message @@ -67,7 +72,9 @@ async def test_post_gzip_compression() -> None: class PingService(PingServiceHandler): """Ping service implementation.""" - async def Ping(self, request: UnaryRequest[PingRequest]) -> UnaryResponse[PingResponse]: + async def Ping( + self, request: UnaryRequest[PingRequest], context: HandlerContext + ) -> UnaryResponse[PingResponse]: """Return a ping response.""" data = request.message @@ -95,7 +102,9 @@ async def test_post_only_accept_encoding_gzip() -> None: class PingService(PingServiceHandler): """Ping service implementation.""" - async def Ping(self, request: UnaryRequest[PingRequest]) -> UnaryResponse[PingResponse]: + async def Ping( + self, request: UnaryRequest[PingRequest], context: HandlerContext + ) -> UnaryResponse[PingResponse]: """Return a ping response.""" data = request.message @@ -121,7 +130,9 @@ async def test_get() -> None: class PingService(PingServiceHandler): """Ping service implementation.""" - async def Ping(self, request: UnaryRequest[PingRequest]) -> UnaryResponse[PingResponse]: + async def Ping( + self, request: UnaryRequest[PingRequest], context: HandlerContext + ) -> UnaryResponse[PingResponse]: """Return a ping response.""" data = request.message @@ -150,7 +161,9 @@ async def test_get_base64() -> None: class PingService(PingServiceHandler): """Ping service implementation.""" - async def Ping(self, request: UnaryRequest[PingRequest]) -> UnaryResponse[PingResponse]: + async def Ping( + self, request: UnaryRequest[PingRequest], context: HandlerContext + ) -> UnaryResponse[PingResponse]: """Return a ping response.""" data = request.message @@ -180,7 +193,9 @@ async def test_unsupported_raw_deflate_compression() -> None: class PingService(PingServiceHandler): """Ping service implementation.""" - async def Ping(self, request: UnaryRequest[PingRequest]) -> UnaryResponse[PingResponse]: + async def Ping( + self, request: UnaryRequest[PingRequest], context: HandlerContext + ) -> UnaryResponse[PingResponse]: """Return a ping response.""" data = request.message diff --git a/tests/testdata/ping/v1/v1connect/ping_connect.py b/tests/testdata/ping/v1/v1connect/ping_connect.py index 5edb4bb..545bfa4 100644 --- a/tests/testdata/ping/v1/v1connect/ping_connect.py +++ b/tests/testdata/ping/v1/v1connect/ping_connect.py @@ -7,6 +7,7 @@ from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler +from connect.handler_context import HandlerContext from connect.options import ConnectOptions from tests.testdata.ping.v1 import ping_pb2 from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse @@ -34,11 +35,11 @@ class PingServiceProcedures(Enum): class PingServiceHandler(metaclass=abc.ABCMeta): """Handler for the ping service.""" - async def Ping(self, request: UnaryRequest[PingRequest]) -> UnaryResponse[PingResponse]: ... + async def Ping(self, request: UnaryRequest[PingRequest], context: HandlerContext) -> UnaryResponse[PingResponse]: ... - async def PingServerStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: ... + async def PingServerStream(self, request: StreamRequest[PingRequest], context: HandlerContext) -> StreamResponse[PingResponse]: ... - async def PingClientStream(self, request: StreamRequest[PingRequest]) -> StreamResponse[PingResponse]: ... + async def PingClientStream(self, request: StreamRequest[PingRequest], context: HandlerContext) -> StreamResponse[PingResponse]: ... def create_PingService_handlers(service: PingServiceHandler, options: ConnectOptions | None = None) -> list[Handler]: From 18e66876e2cd88ca7f2dfcc752ce4ee067f47520 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Fri, 16 May 2025 20:54:08 +0900 Subject: [PATCH 3/3] connect: add doc --- src/connect/call_options.py | 4 ++++ src/connect/client.py | 4 ++++ src/connect/handler_context.py | 15 +++++++++++++++ 3 files changed, 23 insertions(+) diff --git a/src/connect/call_options.py b/src/connect/call_options.py index 3a6462c..adf3f51 100644 --- a/src/connect/call_options.py +++ b/src/connect/call_options.py @@ -1,9 +1,13 @@ +"""Options and configuration for making calls, including timeout and abort event support.""" + import asyncio from pydantic import BaseModel, ConfigDict, Field class CallOptions(BaseModel): + """Options for configuring a call, such as timeout and abort event.""" + model_config = ConfigDict(arbitrary_types_allowed=True) timeout: float | None = Field(default=None) diff --git a/src/connect/client.py b/src/connect/client.py index 211b0e1..6eb2e68 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -316,6 +316,7 @@ async def call_unary( Args: request (UnaryRequest[T_Request]): The request object containing the data to be sent to the server. + call_options (CallOptions | None, optional): Optional call options for the request. Defaults to None. Returns: UnaryResponse[T_Response]: The response object containing the data received from the server. @@ -332,6 +333,7 @@ async def call_server_stream( Args: request (StreamRequest[T_Request]): The request object containing the data to be sent to the server. + call_options (CallOptions | None, optional): Optional call options for the request. Defaults to None. Yields: StreamResponse[T_Response]: The response objects received from the server. @@ -361,6 +363,7 @@ async def call_client_stream( Args: request (StreamRequest[T_Request]): The request object containing the client-streaming data to be sent to the server. + call_options (CallOptions | None, optional): Optional call options for the request. Defaults to None. Yields: StreamResponse[T_Response]: An asynchronous generator that yields @@ -393,6 +396,7 @@ async def call_bidi_stream( Args: request (StreamRequest[T_Request]): The stream request object containing the data to be sent to the server. + call_options (CallOptions | None, optional): Optional call options for Yields: StreamResponse[T_Response]: The stream response object received from the server. diff --git a/src/connect/handler_context.py b/src/connect/handler_context.py index 44c26b9..522a6cd 100644 --- a/src/connect/handler_context.py +++ b/src/connect/handler_context.py @@ -1,10 +1,25 @@ +"""Provides the HandlerContext class for managing operation timeouts and tracking remaining time.""" + import time class HandlerContext: + """HandlerContext manages an optional timeout for operations, allowing tracking of the remaining time until a deadline. + + Attributes: + _deadline (float | None): The UNIX timestamp representing the deadline, or None if no timeout is set. + + """ + _deadline: float | None def __init__(self, timeout: float | None) -> None: + """Initialize HandlerContext with an optional timeout. + + Args: + timeout (float | None): The timeout duration in seconds, or None for no timeout. + + """ self._deadline = time.time() + timeout if timeout else None def timeout_remaining(self) -> float | None: