Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions conformance/client_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import AsyncGenerator
from typing import Any

from connect.call_options import CallOptions
from connect.connect import StreamRequest, UnaryRequest
from connect.connection_pool import AsyncConnectionPool
from connect.error import ConnectError
Expand Down Expand Up @@ -253,6 +254,8 @@ async def delayed_abort() -> None:
UnaryRequest(
content=req,
headers=headers,
),
CallOptions(
timeout=msg.timeout_ms / 1000,
abort_event=abort_event,
),
Expand Down Expand Up @@ -290,8 +293,10 @@ async def delayed_abort() -> None:
asyncio.create_task(delayed_abort())

async with getattr(client, msg.method)(
StreamRequest(
content=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event
StreamRequest(content=_reqs(), headers=headers),
CallOptions(
timeout=msg.timeout_ms / 1000,
abort_event=abort_event,
),
) as resp:
async for message in resp.messages:
Expand All @@ -314,8 +319,10 @@ async def delayed_abort() -> None:
headers = to_connect_headers(msg.request_headers)

async with getattr(client, msg.method)(
StreamRequest(
content=reqs, headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event
StreamRequest(content=reqs, headers=headers),
CallOptions(
timeout=msg.timeout_ms / 1000,
abort_event=abort_event,
),
) as resp:
if msg.cancel.HasField("after_close_send_ms"):
Expand Down Expand Up @@ -356,8 +363,10 @@ async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]:
headers = to_connect_headers(msg.request_headers)

async with getattr(client, msg.method)(
StreamRequest(
content=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event
StreamRequest(content=_reqs(), headers=headers),
CallOptions(
timeout=msg.timeout_ms / 1000,
abort_event=abort_event,
),
) as resp:
if msg.cancel.HasField("before_close_send"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from connect.client import Client
import connect.connect
from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler, BidiStreamHandler
from connect.handler_context import HandlerContext
from connect.options import ClientOptions, ConnectOptions
from connect.connection_pool import AsyncConnectionPool
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor
Expand Down Expand Up @@ -66,22 +67,22 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti
class ConformanceServiceHandler:
"""Handler for the conformanceService service."""

async def Unary(self, request: connect.connect.UnaryRequest[UnaryRequest]) -> connect.connect.UnaryResponse[UnaryResponse]:
async def Unary(self, request: connect.connect.UnaryRequest[UnaryRequest], context: HandlerContext) -> connect.connect.UnaryResponse[UnaryResponse]:
raise NotImplementedError()

async def ServerStream(self, request: connect.connect.StreamRequest[ServerStreamRequest]) -> connect.connect.StreamResponse[ServerStreamResponse]:
async def ServerStream(self, request: connect.connect.StreamRequest[ServerStreamRequest], context: HandlerContext) -> connect.connect.StreamResponse[ServerStreamResponse]:
raise NotImplementedError()

async def ClientStream(self, request: connect.connect.StreamRequest[ClientStreamRequest]) -> connect.connect.StreamResponse[ClientStreamResponse]:
async def ClientStream(self, request: connect.connect.StreamRequest[ClientStreamRequest], context: HandlerContext) -> connect.connect.StreamResponse[ClientStreamResponse]:
raise NotImplementedError()

async def BidiStream(self, request: connect.connect.StreamRequest[BidiStreamRequest]) -> connect.connect.StreamResponse[BidiStreamResponse]:
async def BidiStream(self, request: connect.connect.StreamRequest[BidiStreamRequest], context: HandlerContext) -> connect.connect.StreamResponse[BidiStreamResponse]:
raise NotImplementedError()

async def Unimplemented(self, request: connect.connect.UnaryRequest[UnimplementedRequest]) -> connect.connect.UnaryResponse[UnimplementedResponse]:
async def Unimplemented(self, request: connect.connect.UnaryRequest[UnimplementedRequest], context: HandlerContext) -> connect.connect.UnaryResponse[UnimplementedResponse]:
raise NotImplementedError()

async def IdempotentUnary(self, request: connect.connect.UnaryRequest[IdempotentUnaryRequest]) -> connect.connect.UnaryResponse[IdempotentUnaryResponse]:
async def IdempotentUnary(self, request: connect.connect.UnaryRequest[IdempotentUnaryRequest], context: HandlerContext) -> connect.connect.UnaryResponse[IdempotentUnaryResponse]:
raise NotImplementedError()


Expand Down
32 changes: 21 additions & 11 deletions conformance/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from connect.code import Code
from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse
from connect.error import ConnectError, ErrorDetail
from connect.handler_context import HandlerContext
from connect.headers import Headers
from connect.middleware import ConnectMiddleware
from starlette.applications import Starlette
Expand Down Expand Up @@ -95,7 +96,9 @@ def code_from_pb_code(code: config_pb2.Code) -> Code:
class ConformanceService(ConformanceServiceHandler):
"""ConformanceService is a service handler that implements various gRPC methods for testing conformance."""

async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryResponse[service_pb2.UnaryResponse]:
async def Unary(
self, request: UnaryRequest[service_pb2.UnaryRequest], context: HandlerContext
) -> UnaryResponse[service_pb2.UnaryResponse]:
"""Handle a unary gRPC request and generates a response based on the provided request definition.

Args:
Expand Down Expand Up @@ -128,10 +131,12 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR
request_any = any_pb2.Any()
request_any.Pack(request.message)

timeout_sec = context.timeout_remaining()

request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=pb_headers_from_headers(request.headers),
requests=[request_any],
timeout_ms=int(request.timeout) if request.timeout else None,
timeout_ms=int(timeout_sec * 1000) if timeout_sec else None,
connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo(
query_params=pb_query_params_from_peer_query(request.peer.query),
),
Expand Down Expand Up @@ -175,7 +180,7 @@ async def Unary(self, request: UnaryRequest[service_pb2.UnaryRequest]) -> UnaryR
return UnaryResponse(content=service_pb2.UnaryResponse(payload=payload), headers=headers, trailers=trailers)

async def IdempotentUnary(
self, request: UnaryRequest[service_pb2.IdempotentUnaryRequest]
self, request: UnaryRequest[service_pb2.IdempotentUnaryRequest], context: HandlerContext
) -> UnaryResponse[service_pb2.IdempotentUnaryResponse]:
"""Handle the IdempotentUnary RPC call.

Expand All @@ -202,10 +207,11 @@ async def IdempotentUnary(
request_any = any_pb2.Any()
request_any.Pack(request.message)

timeout_sec = context.timeout_remaining()
request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=pb_headers_from_headers(request.headers),
requests=[request_any],
timeout_ms=int(request.timeout) if request.timeout else None,
timeout_ms=int(timeout_sec * 1000) if timeout_sec else None,
connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo(
query_params=pb_query_params_from_peer_query(request.peer.query),
),
Expand Down Expand Up @@ -251,7 +257,7 @@ async def IdempotentUnary(
)

async def ClientStream(
self, request: StreamRequest[service_pb2.ClientStreamRequest]
self, request: StreamRequest[service_pb2.ClientStreamRequest], context: HandlerContext
) -> StreamResponse[service_pb2.ClientStreamResponse]:
"""Handle a bidirectional streaming RPC where the client sends a stream of `ClientStreamRequest` messages and receives a single `ClientStreamResponse` message.

Expand Down Expand Up @@ -287,10 +293,11 @@ async def ClientStream(
message_any.Pack(message)
messages.append(message_any)

timeout_sec = context.timeout_remaining()
request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=pb_headers_from_headers(request.headers),
requests=messages,
timeout_ms=int(request.timeout) if request.timeout else None,
timeout_ms=int(timeout_sec * 1000) if timeout_sec else None,
connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo(
query_params=pb_query_params_from_peer_query(request.peer.query),
),
Expand Down Expand Up @@ -339,7 +346,7 @@ async def ClientStream(
)

async def ServerStream(
self, request: StreamRequest[service_pb2.ServerStreamRequest]
self, request: StreamRequest[service_pb2.ServerStreamRequest], context: HandlerContext
) -> StreamResponse[service_pb2.ServerStreamResponse]:
"""Handle a server-side streaming RPC call.

Expand Down Expand Up @@ -379,10 +386,11 @@ async def ServerStream(
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

timeout_sec = context.timeout_remaining()
request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=pb_headers_from_headers(request.headers),
requests=messages,
timeout_ms=int(request.timeout) if request.timeout else None,
timeout_ms=int(timeout_sec * 1000) if timeout_sec else None,
connect_get_info=service_pb2.ConformancePayload.ConnectGetInfo(
query_params=pb_query_params_from_peer_query(request.peer.query),
),
Expand Down Expand Up @@ -442,7 +450,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.ServerStreamResponse]:
)

async def BidiStream(
self, request: StreamRequest[service_pb2.BidiStreamRequest]
self, request: StreamRequest[service_pb2.BidiStreamRequest], context: HandlerContext
) -> StreamResponse[service_pb2.BidiStreamResponse]:
"""Handle a bidirectional streaming RPC.

Expand Down Expand Up @@ -488,6 +496,8 @@ async def BidiStream(
headers = headers_from_pb_headers(response_definition.response_headers)
trailers = headers_from_pb_headers(response_definition.response_trailers)

timeout_sec = context.timeout_remaining()

async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]:
nonlocal response_index

Expand All @@ -496,7 +506,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]:
request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=pb_headers_from_headers(request.headers),
requests=messages,
timeout_ms=int(request.timeout) if request.timeout else None,
timeout_ms=int(timeout_sec * 1000) if timeout_sec else None,
)
else:
request_info = None
Expand Down Expand Up @@ -525,7 +535,7 @@ async def iterator() -> typing.AsyncIterator[service_pb2.BidiStreamResponse]:
request_info = service_pb2.ConformancePayload.RequestInfo(
request_headers=pb_headers_from_headers(request.headers),
requests=messages,
timeout_ms=int(request.timeout) if request.timeout else None,
timeout_ms=int(timeout_sec * 1000) if timeout_sec else None,
)

detail = any_pb2.Any()
Expand Down
9 changes: 5 additions & 4 deletions examples/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import hypercorn.asyncio
from connect.connect import UnaryRequest, UnaryResponse
from connect.interceptor import Interceptor, UnaryFunc
from connect.handler_context import HandlerContext
from connect.handler_interceptor import HandlerInterceptor, UnaryFunc
from connect.middleware import ConnectMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware
Expand All @@ -23,13 +24,13 @@ async def Say(self, request: UnaryRequest[SayRequest]) -> UnaryResponse[SayRespo
return UnaryResponse(SayResponse(sentence=data.sentence))


class IPRestrictionInterceptor(Interceptor):
class IPRestrictionInterceptor(HandlerInterceptor):
"""IP restriction interceptor."""

def wrap_unary(self, next: UnaryFunc) -> UnaryFunc:
"""Wrap a unary function with the interceptor."""

async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]:
async def _wrapped(request: UnaryRequest[Any], context: HandlerContext) -> UnaryResponse[Any]:
ip_allow_list = os.environ.get("IP_ALLOW_LIST", "").split(",")
if not ip_allow_list:
raise Exception("White list not found")
Expand All @@ -45,7 +46,7 @@ async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]:
if ip not in ip_allow_list:
raise Exception("IP not allowed")

return await next(request)
return await next(request, context)

return _wrapped

Expand Down
17 changes: 17 additions & 0 deletions src/connect/call_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Options and configuration for making calls, including timeout and abort event support."""

import asyncio

from pydantic import BaseModel, ConfigDict, Field


class CallOptions(BaseModel):
"""Options for configuring a call, such as timeout and abort event."""

model_config = ConfigDict(arbitrary_types_allowed=True)

timeout: float | None = Field(default=None)
"""Timeout for the call in seconds."""

abort_event: asyncio.Event | None = Field(default=None)
"""Event to abort the call."""
Loading
Loading