Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 45 additions & 45 deletions conformance/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
logger = logging.getLogger("conformance.server")


def headers_from_svc_headers(headers: typing.Iterable[service_pb2.Header]) -> Headers:
def headers_from_pb_headers(headers: typing.Iterable[service_pb2.Header]) -> Headers:
"""Convert a list of headers to a Headers object."""
header = Headers()
for h in headers:
Expand All @@ -34,7 +34,7 @@ def headers_from_svc_headers(headers: typing.Iterable[service_pb2.Header]) -> He
return header


def svc_headers_from_headers(headers: Headers) -> list[service_pb2.Header]:
def pb_headers_from_headers(headers: Headers) -> list[service_pb2.Header]:
"""Convert a Headers object to a list of headers."""
svc_headers = []
for key, value in headers.items():
Expand All @@ -43,7 +43,7 @@ def svc_headers_from_headers(headers: Headers) -> list[service_pb2.Header]:
return svc_headers


def svc_query_params_from_peer_query(query: typing.Mapping[str, str]) -> list[service_pb2.Header]:
def pb_query_params_from_peer_query(query: typing.Mapping[str, str]) -> list[service_pb2.Header]:
"""Convert a query mapping to a list of headers."""
svc_query_params = []
for key, value in query.items():
Expand All @@ -52,7 +52,7 @@ def svc_query_params_from_peer_query(query: typing.Mapping[str, str]) -> list[se
return svc_query_params


def code_from_svc_code(code: config_pb2.Code) -> Code:
def code_from_pb_code(code: config_pb2.Code) -> Code:
"""Convert a service code to a Connect code."""
match code:
case config_pb2.CODE_UNSPECIFIED:
Expand Down Expand Up @@ -131,11 +131,11 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR
request_any.Pack(request.message)

request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=svc_headers_from_headers(request.headers),
request_headers=pb_headers_from_headers(request.headers),
requests=[request_any],
timeout_ms=None,
timeout_ms=int(request.timeout) if request.timeout else None,
connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo(
query_params=svc_query_params_from_peer_query(request.peer.query),
query_params=pb_query_params_from_peer_query(request.peer.query),
),
)

Expand All @@ -145,16 +145,16 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR
detail.Pack(request_info)
response_definition.error.details.append(detail)

headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

metadata = Headers()
metadata.update(headers)
metadata.update(trailers)

error = ConnectError(
message=response_definition.error.message,
code=code_from_svc_code(response_definition.error.code),
code=code_from_pb_code(response_definition.error.code),
details=[ErrorDetail(pb_any=error) for error in response_definition.error.details],
metadata=metadata,
)
Expand All @@ -165,8 +165,8 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR
)

if response_definition:
headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

if response_definition.response_delay_ms:
await asyncio.sleep(response_definition.response_delay_ms / 1000)
Expand Down Expand Up @@ -212,11 +212,11 @@ async def IdempotentUnary(
request_any.Pack(request.message)

request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=svc_headers_from_headers(request.headers),
request_headers=pb_headers_from_headers(request.headers),
requests=[request_any],
timeout_ms=None,
timeout_ms=int(request.timeout) if request.timeout else None,
connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo(
query_params=svc_query_params_from_peer_query(request.peer.query),
query_params=pb_query_params_from_peer_query(request.peer.query),
),
)

Expand All @@ -226,16 +226,16 @@ async def IdempotentUnary(
detail.Pack(request_info)
response_definition.error.details.append(detail)

headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

metadata = Headers()
metadata.update(headers)
metadata.update(trailers)

error = ConnectError(
message=response_definition.error.message,
code=code_from_svc_code(response_definition.error.code),
code=code_from_pb_code(response_definition.error.code),
details=[ErrorDetail(pb_any=error) for error in response_definition.error.details],
metadata=metadata,
)
Expand All @@ -246,8 +246,8 @@ async def IdempotentUnary(
)

if response_definition:
headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

if response_definition.response_delay_ms:
await asyncio.sleep(response_definition.response_delay_ms / 1000)
Expand Down Expand Up @@ -304,11 +304,11 @@ async def ClientStream(
messages.append(message_any)

request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=svc_headers_from_headers(request.headers),
request_headers=pb_headers_from_headers(request.headers),
requests=messages,
timeout_ms=None,
timeout_ms=int(request.timeout) if request.timeout else None,
connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo(
query_params=svc_query_params_from_peer_query(request.peer.query),
query_params=pb_query_params_from_peer_query(request.peer.query),
),
)

Expand All @@ -319,16 +319,16 @@ async def ClientStream(
detail.Pack(request_info)
response_definition.error.details.append(detail)

headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

metadata = Headers()
metadata.update(headers)
metadata.update(trailers)

error = ConnectError(
message=response_definition.error.message,
code=code_from_svc_code(response_definition.error.code),
code=code_from_pb_code(response_definition.error.code),
details=[ErrorDetail(pb_any=error) for error in response_definition.error.details],
metadata=metadata,
)
Expand All @@ -339,8 +339,8 @@ async def ClientStream(

if response_definition:
payload.data = response_definition.response_data
headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

if response_definition and response_definition.response_delay_ms:
await asyncio.sleep(response_definition.response_delay_ms / 1000)
Expand Down Expand Up @@ -399,15 +399,15 @@ async def ServerStream(
headers = None
trailers = None
if response_definition:
headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=svc_headers_from_headers(request.headers),
request_headers=pb_headers_from_headers(request.headers),
requests=messages,
timeout_ms=None,
timeout_ms=int(request.timeout) if request.timeout else None,
connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo(
query_params=svc_query_params_from_peer_query(request.peer.query),
query_params=pb_query_params_from_peer_query(request.peer.query),
),
)

Expand Down Expand Up @@ -437,8 +437,8 @@ async def iterator() -> typing.AsyncIterator[service_pb2.ServerStreamResponse]:
)

if response_definition.HasField("error"):
headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

metadata = Headers()
metadata.update(headers)
Expand All @@ -451,7 +451,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.ServerStreamResponse]:

error = ConnectError(
message=response_definition.error.message,
code=code_from_svc_code(response_definition.error.code),
code=code_from_pb_code(response_definition.error.code),
details=[ErrorDetail(pb_any=error) for error in response_definition.error.details],
metadata=metadata,
)
Expand Down Expand Up @@ -515,18 +515,18 @@ async def BidiStream(
first_response = False

if response_definition:
headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]:
nonlocal response_index

while response_definition and response_index < len(response_definition.response_data):
if response_index == 0:
request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=svc_headers_from_headers(request.headers),
request_headers=pb_headers_from_headers(request.headers),
requests=messages,
timeout_ms=None,
timeout_ms=int(request.timeout) if request.timeout else None,
)
else:
request_info = None
Expand All @@ -544,18 +544,18 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]:
yield response

if response_definition and response_definition.HasField("error"):
headers = headers_from_svc_headers(response_definition.response_headers)
trailers = headers_from_svc_headers(response_definition.response_trailers)
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

metadata = Headers()
metadata.update(headers)
metadata.update(trailers)

if response_index == 0:
request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=svc_headers_from_headers(request.headers),
request_headers=pb_headers_from_headers(request.headers),
requests=messages,
timeout_ms=None,
timeout_ms=int(request.timeout) if request.timeout else None,
)

detail = any_pb2.Any()
Expand All @@ -564,7 +564,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]:

error = ConnectError(
message=response_definition.error.message,
code=code_from_svc_code(response_definition.error.code),
code=code_from_pb_code(response_definition.error.code),
details=[ErrorDetail(pb_any=error) for error in response_definition.error.details],
metadata=metadata,
)
Expand Down
47 changes: 35 additions & 12 deletions src/connect/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]:
return handlers


UnaryImplementationFunc = Callable[[UnaryHandlerConn], Awaitable[None]]
StreamImplementationFunc = Callable[[StreamingHandlerConn], Awaitable[None]]
UnaryImplementationFunc = Callable[[UnaryHandlerConn, float | None], Awaitable[None]]
StreamImplementationFunc = Callable[[StreamingHandlerConn, float | None], Awaitable[None]]


class Handler:
Expand Down Expand Up @@ -310,7 +310,7 @@ def is_stream(
"""
signature = inspect.signature(impl)
parameters = signature.parameters
return len(parameters) == 1 and next(iter(parameters.values())).annotation == StreamingHandlerConn
return len(parameters) == 2 and next(iter(parameters.values())).annotation == StreamingHandlerConn

def is_unary(self, impl: UnaryImplementationFunc | StreamImplementationFunc) -> TypeGuard[UnaryImplementationFunc]:
"""Determine if the given implementation function is a unary implementation.
Expand All @@ -324,7 +324,7 @@ def is_unary(self, impl: UnaryImplementationFunc | StreamImplementationFunc) ->
"""
signature = inspect.signature(impl)
parameters = signature.parameters
return len(parameters) == 1 and next(iter(parameters.values())).annotation == UnaryHandlerConn
return len(parameters) == 2 and next(iter(parameters.values())).annotation == UnaryHandlerConn

async def stream_handle(
self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter
Expand Down Expand Up @@ -352,8 +352,21 @@ async def stream_handle(
implementation = self.implementation
if not self.is_stream(implementation):
raise ValueError(f"Invalid function type for stream handler: {implementation}")

timeout = request.headers.get(CONNECT_HEADER_TIMEOUT, None)
try:
await implementation(conn)
if timeout:
try:
timeout_ms = int(timeout)
timeout_sec = timeout_ms / 1000
except ValueError as e:
raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e

with anyio.fail_after(delay=timeout_sec):
await implementation(conn, timeout_ms)
else:
await implementation(conn, None)

except Exception as e:
error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL)

Expand Down Expand Up @@ -390,14 +403,15 @@ async def unary_handle(
try:
if timeout:
try:
timeout_sec = float(timeout) / 1000
timeout_ms = int(timeout)
timeout_sec = timeout_ms / 1000
except ValueError as e:
raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e

with anyio.fail_after(delay=timeout_sec):
await implementation(conn)
await implementation(conn, timeout_ms)
else:
await implementation(conn)
await implementation(conn, None)

except Exception as e:
error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL)
Expand Down Expand Up @@ -454,8 +468,11 @@ async def _untyped(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response

untyped = apply_interceptors(_untyped, options.interceptors)

async def implementation(conn: UnaryHandlerConn) -> None:
async def implementation(conn: UnaryHandlerConn, timeout: float | None) -> None:
request = await receive_unary_request(conn, input)
if timeout:
request.timeout = timeout

response = await untyped(request)

if not isinstance(response.message, output):
Expand Down Expand Up @@ -523,8 +540,10 @@ async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Respon

untyped = apply_interceptors(_untyped, options.interceptors)

async def implementation(conn: StreamingHandlerConn) -> None:
async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None:
request = await receive_stream_request(conn, input)
if timeout:
request.timeout = timeout

response = await untyped(request)

Expand Down Expand Up @@ -596,8 +615,10 @@ async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Respon

untyped = apply_interceptors(_untyped, options.interceptors)

async def implementation(conn: StreamingHandlerConn) -> None:
async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None:
request = await receive_stream_request(conn, input)
if timeout:
request.timeout = timeout

response = await untyped(request)

Expand Down Expand Up @@ -688,8 +709,10 @@ async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Respon

untyped = apply_interceptors(_untyped, options.interceptors)

async def implementation(conn: StreamingHandlerConn) -> None:
async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None:
request = await receive_stream_request(conn, input)
if timeout:
request.timeout = timeout

response = await untyped(request)

Expand Down
Loading