From 40b357141b6bc17a1becd463cfba0b73243de64e Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 17:15:03 +0900 Subject: [PATCH] connect: add timeout to request property --- conformance/server.py | 90 +++++++++++++++++++++--------------------- src/connect/handler.py | 47 ++++++++++++++++------ 2 files changed, 80 insertions(+), 57 deletions(-) diff --git a/conformance/server.py b/conformance/server.py index 65830d8..2359013 100644 --- a/conformance/server.py +++ b/conformance/server.py @@ -23,7 +23,7 @@ logger = logging.getLogger("conformance.server") -def headers_from_svc_headers(headers: typing.Iterable[service_pb2.Header]) -> Headers: +def headers_from_pb_headers(headers: typing.Iterable[service_pb2.Header]) -> Headers: """Convert a list of headers to a Headers object.""" header = Headers() for h in headers: @@ -34,7 +34,7 @@ def headers_from_svc_headers(headers: typing.Iterable[service_pb2.Header]) -> He return header -def svc_headers_from_headers(headers: Headers) -> list[service_pb2.Header]: +def pb_headers_from_headers(headers: Headers) -> list[service_pb2.Header]: """Convert a Headers object to a list of headers.""" svc_headers = [] for key, value in headers.items(): @@ -43,7 +43,7 @@ def svc_headers_from_headers(headers: Headers) -> list[service_pb2.Header]: return svc_headers -def svc_query_params_from_peer_query(query: typing.Mapping[str, str]) -> list[service_pb2.Header]: +def pb_query_params_from_peer_query(query: typing.Mapping[str, str]) -> list[service_pb2.Header]: """Convert a query mapping to a list of headers.""" svc_query_params = [] for key, value in query.items(): @@ -52,7 +52,7 @@ def svc_query_params_from_peer_query(query: typing.Mapping[str, str]) -> list[se return svc_query_params -def code_from_svc_code(code: config_pb2.Code) -> Code: +def code_from_pb_code(code: config_pb2.Code) -> Code: """Convert a service code to a Connect code.""" match code: case config_pb2.CODE_UNSPECIFIED: @@ -131,11 +131,11 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR request_any.Pack(request.message) request_info = service_pb2.ConformancePayload.RequestInfo( - request_headers=svc_headers_from_headers(request.headers), + request_headers=pb_headers_from_headers(request.headers), requests=[request_any], - timeout_ms=None, + timeout_ms=int(request.timeout) if request.timeout else None, connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo( - query_params=svc_query_params_from_peer_query(request.peer.query), + query_params=pb_query_params_from_peer_query(request.peer.query), ), ) @@ -145,8 +145,8 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR detail.Pack(request_info) response_definition.error.details.append(detail) - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) metadata = Headers() metadata.update(headers) @@ -154,7 +154,7 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR error = ConnectError( message=response_definition.error.message, - code=code_from_svc_code(response_definition.error.code), + code=code_from_pb_code(response_definition.error.code), details=[ErrorDetail(pb_any=error) for error in response_definition.error.details], metadata=metadata, ) @@ -165,8 +165,8 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR ) if response_definition: - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) if response_definition.response_delay_ms: await asyncio.sleep(response_definition.response_delay_ms / 1000) @@ -212,11 +212,11 @@ async def IdempotentUnary( request_any.Pack(request.message) request_info = service_pb2.ConformancePayload.RequestInfo( - request_headers=svc_headers_from_headers(request.headers), + request_headers=pb_headers_from_headers(request.headers), requests=[request_any], - timeout_ms=None, + timeout_ms=int(request.timeout) if request.timeout else None, connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo( - query_params=svc_query_params_from_peer_query(request.peer.query), + query_params=pb_query_params_from_peer_query(request.peer.query), ), ) @@ -226,8 +226,8 @@ async def IdempotentUnary( detail.Pack(request_info) response_definition.error.details.append(detail) - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) metadata = Headers() metadata.update(headers) @@ -235,7 +235,7 @@ async def IdempotentUnary( error = ConnectError( message=response_definition.error.message, - code=code_from_svc_code(response_definition.error.code), + code=code_from_pb_code(response_definition.error.code), details=[ErrorDetail(pb_any=error) for error in response_definition.error.details], metadata=metadata, ) @@ -246,8 +246,8 @@ async def IdempotentUnary( ) if response_definition: - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) if response_definition.response_delay_ms: await asyncio.sleep(response_definition.response_delay_ms / 1000) @@ -304,11 +304,11 @@ async def ClientStream( messages.append(message_any) request_info = service_pb2.ConformancePayload.RequestInfo( - request_headers=svc_headers_from_headers(request.headers), + request_headers=pb_headers_from_headers(request.headers), requests=messages, - timeout_ms=None, + timeout_ms=int(request.timeout) if request.timeout else None, connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo( - query_params=svc_query_params_from_peer_query(request.peer.query), + query_params=pb_query_params_from_peer_query(request.peer.query), ), ) @@ -319,8 +319,8 @@ async def ClientStream( detail.Pack(request_info) response_definition.error.details.append(detail) - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) metadata = Headers() metadata.update(headers) @@ -328,7 +328,7 @@ async def ClientStream( error = ConnectError( message=response_definition.error.message, - code=code_from_svc_code(response_definition.error.code), + code=code_from_pb_code(response_definition.error.code), details=[ErrorDetail(pb_any=error) for error in response_definition.error.details], metadata=metadata, ) @@ -339,8 +339,8 @@ async def ClientStream( if response_definition: payload.data = response_definition.response_data - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) if response_definition and response_definition.response_delay_ms: await asyncio.sleep(response_definition.response_delay_ms / 1000) @@ -399,15 +399,15 @@ async def ServerStream( headers = None trailers = None if response_definition: - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) request_info = service_pb2.ConformancePayload.RequestInfo( - request_headers=svc_headers_from_headers(request.headers), + request_headers=pb_headers_from_headers(request.headers), requests=messages, - timeout_ms=None, + timeout_ms=int(request.timeout) if request.timeout else None, connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo( - query_params=svc_query_params_from_peer_query(request.peer.query), + query_params=pb_query_params_from_peer_query(request.peer.query), ), ) @@ -437,8 +437,8 @@ async def iterator() -> typing.AsyncIterator[service_pb2.ServerStreamResponse]: ) if response_definition.HasField("error"): - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) metadata = Headers() metadata.update(headers) @@ -451,7 +451,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.ServerStreamResponse]: error = ConnectError( message=response_definition.error.message, - code=code_from_svc_code(response_definition.error.code), + code=code_from_pb_code(response_definition.error.code), details=[ErrorDetail(pb_any=error) for error in response_definition.error.details], metadata=metadata, ) @@ -515,8 +515,8 @@ async def BidiStream( first_response = False if response_definition: - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]: nonlocal response_index @@ -524,9 +524,9 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]: while response_definition and response_index < len(response_definition.response_data): if response_index == 0: request_info = service_pb2.ConformancePayload.RequestInfo( - request_headers=svc_headers_from_headers(request.headers), + request_headers=pb_headers_from_headers(request.headers), requests=messages, - timeout_ms=None, + timeout_ms=int(request.timeout) if request.timeout else None, ) else: request_info = None @@ -544,8 +544,8 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]: yield response if response_definition and response_definition.HasField("error"): - headers = headers_from_svc_headers(response_definition.response_headers) - trailers = headers_from_svc_headers(response_definition.response_trailers) + headers = headers_from_pb_headers(response_definition.response_headers) + trailers = headers_from_pb_headers(response_definition.response_trailers) metadata = Headers() metadata.update(headers) @@ -553,9 +553,9 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]: if response_index == 0: request_info = service_pb2.ConformancePayload.RequestInfo( - request_headers=svc_headers_from_headers(request.headers), + request_headers=pb_headers_from_headers(request.headers), requests=messages, - timeout_ms=None, + timeout_ms=int(request.timeout) if request.timeout else None, ) detail = any_pb2.Any() @@ -564,7 +564,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]: error = ConnectError( message=response_definition.error.message, - code=code_from_svc_code(response_definition.error.code), + code=code_from_pb_code(response_definition.error.code), details=[ErrorDetail(pb_any=error) for error in response_definition.error.details], metadata=metadata, ) diff --git a/src/connect/handler.py b/src/connect/handler.py index 882a016..12f3910 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -148,8 +148,8 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: return handlers -UnaryImplementationFunc = Callable[[UnaryHandlerConn], Awaitable[None]] -StreamImplementationFunc = Callable[[StreamingHandlerConn], Awaitable[None]] +UnaryImplementationFunc = Callable[[UnaryHandlerConn, float | None], Awaitable[None]] +StreamImplementationFunc = Callable[[StreamingHandlerConn, float | None], Awaitable[None]] class Handler: @@ -310,7 +310,7 @@ def is_stream( """ signature = inspect.signature(impl) parameters = signature.parameters - return len(parameters) == 1 and next(iter(parameters.values())).annotation == StreamingHandlerConn + return len(parameters) == 2 and next(iter(parameters.values())).annotation == StreamingHandlerConn def is_unary(self, impl: UnaryImplementationFunc | StreamImplementationFunc) -> TypeGuard[UnaryImplementationFunc]: """Determine if the given implementation function is a unary implementation. @@ -324,7 +324,7 @@ def is_unary(self, impl: UnaryImplementationFunc | StreamImplementationFunc) -> """ signature = inspect.signature(impl) parameters = signature.parameters - return len(parameters) == 1 and next(iter(parameters.values())).annotation == UnaryHandlerConn + return len(parameters) == 2 and next(iter(parameters.values())).annotation == UnaryHandlerConn async def stream_handle( self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter @@ -352,8 +352,21 @@ async def stream_handle( implementation = self.implementation if not self.is_stream(implementation): raise ValueError(f"Invalid function type for stream handler: {implementation}") + + timeout = request.headers.get(CONNECT_HEADER_TIMEOUT, None) try: - await implementation(conn) + if timeout: + try: + timeout_ms = int(timeout) + timeout_sec = timeout_ms / 1000 + except ValueError as e: + raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e + + with anyio.fail_after(delay=timeout_sec): + await implementation(conn, timeout_ms) + else: + await implementation(conn, None) + except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) @@ -390,14 +403,15 @@ async def unary_handle( try: if timeout: try: - timeout_sec = float(timeout) / 1000 + timeout_ms = int(timeout) + timeout_sec = timeout_ms / 1000 except ValueError as e: raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e with anyio.fail_after(delay=timeout_sec): - await implementation(conn) + await implementation(conn, timeout_ms) else: - await implementation(conn) + await implementation(conn, None) except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) @@ -454,8 +468,11 @@ async def _untyped(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: UnaryHandlerConn) -> None: + async def implementation(conn: UnaryHandlerConn, timeout: float | None) -> None: request = await receive_unary_request(conn, input) + if timeout: + request.timeout = timeout + response = await untyped(request) if not isinstance(response.message, output): @@ -523,8 +540,10 @@ async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Respon untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn) -> None: + async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: request = await receive_stream_request(conn, input) + if timeout: + request.timeout = timeout response = await untyped(request) @@ -596,8 +615,10 @@ async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Respon untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn) -> None: + async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: request = await receive_stream_request(conn, input) + if timeout: + request.timeout = timeout response = await untyped(request) @@ -688,8 +709,10 @@ async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Respon untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn) -> None: + async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: request = await receive_stream_request(conn, input) + if timeout: + request.timeout = timeout response = await untyped(request)