Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conformance/client_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ features:
- STREAM_TYPE_CLIENT_STREAM
- STREAM_TYPE_SERVER_STREAM
- STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM
- STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM
# - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM

supports_h2c: true
supports_tls: true
Expand Down
3 changes: 1 addition & 2 deletions conformance/client_known_failing.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
# Cancellation is not supported yet
Client Cancellation/**

175 changes: 133 additions & 42 deletions conformance/client_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import ssl
import struct
import sys
import time
import traceback
from collections.abc import AsyncGenerator
from typing import Any
Expand Down Expand Up @@ -142,6 +141,17 @@ def to_pb_headers(headers: Headers) -> list[service_pb2.Header]:
]


def to_connect_headers(pb_headers: RepeatedCompositeFieldContainer[service_pb2.Header]) -> Headers:
headers = Headers()
for h in pb_headers:
if key := headers.get(h.name.lower()):
headers[key] = f"{headers[key]}, {', '.join(h.value)}"
else:
headers[h.name.lower()] = ", ".join(h.value)

return headers


async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_compat_pb2.ClientCompatResponse:
"""Handle a client compatibility request and returns a response.

Expand Down Expand Up @@ -204,9 +214,6 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c

url = f"{proto}://{msg.host}:{msg.port}"

if msg.request_delay_ms > 0:
time.sleep(msg.request_delay_ms / 1000.0)

async with AsyncClientSession(http1=http1, http2=http2, ssl_context=ssl_context) as session:
payloads = []
try:
Expand All @@ -216,20 +223,28 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c

client = service_connect.ConformanceServiceClient(base_url=url, session=session, options=options)
if msg.stream_type == config_pb2.STREAM_TYPE_UNARY:
if msg.request_delay_ms > 0:
await asyncio.sleep(msg.request_delay_ms / 1000)

abort_event = asyncio.Event()
req = await anext(reqs)

header = Headers()
for h in msg.request_headers:
if key := header.get(h.name.lower()):
header[key] = f"{header[key]}, {', '.join(h.value)}"
else:
header[h.name.lower()] = ", ".join(h.value)
if msg.cancel.after_close_send_ms > 0:

async def delayed_abort() -> None:
await asyncio.sleep(msg.cancel.after_close_send_ms / 1000)
abort_event.set()

asyncio.create_task(delayed_abort())

headers = to_connect_headers(msg.request_headers)

resp = await getattr(client, msg.method)(
UnaryRequest(
message=req,
headers=header,
headers=headers,
timeout=msg.timeout_ms / 1000,
abort_event=abort_event,
),
)
payloads.append(resp.message.payload)
Expand All @@ -243,29 +258,116 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c
response_trailers=to_pb_headers(resp.trailers),
),
)
elif msg.stream_type == config_pb2.STREAM_TYPE_CLIENT_STREAM:
abort_event = asyncio.Event()
headers = to_connect_headers(msg.request_headers)

async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]:
async for req in reqs:
if msg.request_delay_ms > 0:
await asyncio.sleep(msg.request_delay_ms / 1000)
yield req

if msg.cancel.HasField("before_close_send"):
abort_event.set()

if msg.cancel.HasField("after_close_send_ms"):

async def delayed_abort() -> None:
await asyncio.sleep(msg.cancel.after_close_send_ms / 1000)
abort_event.set()

asyncio.create_task(delayed_abort())

async with getattr(client, msg.method)(
StreamRequest(
messages=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event
),
) as resp:
async for message in resp.messages:
payloads.append(message.payload)

return client_compat_pb2.ClientCompatResponse(
test_name=msg.test_name,
response=client_compat_pb2.ClientResponseResult(
payloads=payloads,
http_status_code=200,
response_headers=to_pb_headers(resp.headers),
response_trailers=to_pb_headers(resp.trailers),
),
)
elif msg.stream_type == config_pb2.STREAM_TYPE_SERVER_STREAM:
abort_event = asyncio.Event()
if msg.request_delay_ms > 0:
await asyncio.sleep(msg.request_delay_ms / 1000)

headers = to_connect_headers(msg.request_headers)

async with getattr(client, msg.method)(
StreamRequest(
messages=reqs, headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event
),
) as resp:
if msg.cancel.HasField("after_close_send_ms"):

async def delayed_abort() -> None:
await asyncio.sleep(msg.cancel.after_close_send_ms / 1000)
abort_event.set()

asyncio.create_task(delayed_abort())

async for message in resp.messages:
payloads.append(message.payload)
if len(payloads) == msg.cancel.after_num_responses:
abort_event.set()

return client_compat_pb2.ClientCompatResponse(
test_name=msg.test_name,
response=client_compat_pb2.ClientResponseResult(
payloads=payloads,
http_status_code=200,
response_headers=to_pb_headers(resp.headers),
response_trailers=to_pb_headers(resp.trailers),
),
)

elif (
msg.stream_type == config_pb2.STREAM_TYPE_CLIENT_STREAM
or msg.stream_type == config_pb2.STREAM_TYPE_SERVER_STREAM
or msg.stream_type == config_pb2.STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM
msg.stream_type == config_pb2.STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM
or msg.stream_type == config_pb2.STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM
):
header = Headers()
for h in msg.request_headers:
if key := header.get(h.name.lower()):
header[key] = f"{header[key]}, {', '.join(h.value)}"
else:
header[h.name.lower()] = ", ".join(h.value)
abort_event = asyncio.Event()

resp = await getattr(client, msg.method)(
async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]:
async for req in reqs:
if msg.request_delay_ms > 0:
await asyncio.sleep(msg.request_delay_ms / 1000)
yield req

headers = to_connect_headers(msg.request_headers)

async with getattr(client, msg.method)(
StreamRequest(
messages=reqs,
headers=header,
timeout=msg.timeout_ms / 1000,
messages=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event
),
)
) as resp:
if msg.cancel.HasField("before_close_send"):
abort_event.set()

async for message in resp.messages:
payloads.append(message.payload)
if msg.cancel.HasField("after_close_send_ms"):

async def delayed_abort() -> None:
await asyncio.sleep(msg.cancel.after_close_send_ms / 1000)
abort_event.set()

asyncio.create_task(delayed_abort())

if msg.cancel.HasField("after_num_responses") and msg.cancel.after_num_responses == 0:
abort_event.set()

async for message in resp.messages:
payloads.append(message.payload)
if len(payloads) == msg.cancel.after_num_responses:
abort_event.set()

return client_compat_pb2.ClientCompatResponse(
test_name=msg.test_name,
Expand Down Expand Up @@ -306,11 +408,6 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c
if "--debug" in sys.argv:
logging.debug("Debug mode enabled")

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

tasks = []

async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None:
"""Run the message handler for a given request."""
try:
Expand All @@ -325,16 +422,10 @@ async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None:

async def read_requests() -> None:
"""Read requests from standard input and process them asynchronously."""
loop = asyncio.get_event_loop()
while req := await loop.run_in_executor(None, read_request):
task = loop.create_task(run_message(req))
tasks.append(task)

loop.run_until_complete(read_requests())
await asyncio.sleep(0.01)

pending_tasks = [t for t in tasks if not t.done()]
if pending_tasks:
logger.info(f"Waiting for {len(pending_tasks)} pending tasks to complete...")
loop.run_until_complete(asyncio.gather(*pending_tasks))
loop.create_task(run_message(req))

logger.info("All done")
loop.close()
asyncio.run(read_requests())
2 changes: 2 additions & 0 deletions conformance/run-testcase.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Connect Unexpected Responses/HTTPVersion:2/TLS:true/client-stream/multiple-responses
Connect Unexpected Responses/HTTPVersion:2/TLS:false/client-stream/ok-but-no-response
94 changes: 65 additions & 29 deletions src/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
These classes allow making unary calls to a specified URL with given request and response types.
"""

from collections.abc import Awaitable, Callable
import contextlib
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any

import httpcore
Expand Down Expand Up @@ -227,7 +228,7 @@ def on_request_send(r: httpcore.Request) -> None:

conn.on_request_send(on_request_send)

await conn.send(request.message, request.timeout)
await conn.send(request.message, request.timeout, abort_event=request.abort_event)

response = await recieve_unary_response(conn=conn, t=output)
return response
Expand Down Expand Up @@ -267,9 +268,9 @@ def on_request_send(r: httpcore.Request) -> None:

conn.on_request_send(on_request_send)

await conn.send(request.messages, request.timeout)
await conn.send(request.messages, request.timeout, request.abort_event)

response = await recieve_stream_response(conn, output, request.spec)
response = await recieve_stream_response(conn, output, request.spec, request.abort_event)
return response

stream_func = apply_interceptors(_stream_func, options.interceptors)
Expand Down Expand Up @@ -299,49 +300,84 @@ async def call_unary(self, request: UnaryRequest[T_Request]) -> UnaryResponse[T_
"""
return await self._call_unary(request)

async def call_server_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]:
"""Asynchronously calls a server streaming RPC (Remote Procedure Call) with the given request.
@contextlib.asynccontextmanager
async def call_server_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]:
"""Initiate a server-streaming RPC call and returns an asynchronous generator that yields responses from the server.

Args:
request (UnaryRequest[T_Request]): The request object containing the data to be sent to the server.
request (StreamRequest[T_Request]): The request object containing the
data to be sent to the server.

Returns:
UnaryResponse[T_Response]: The response object containing the data received from the server.
Yields:
StreamResponse[T_Response]: The response objects received from the server.

"""
return await self._call_stream(StreamType.ServerStream, request)
Raises:
Any exceptions that occur during the streaming process.

Notes:
- This method ensures that the response stream is properly closed
after the generator is exhausted or an exception occurs.
- The type parameters `T_Request` and `T_Response` represent the
request and response types, respectively.

async def call_client_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]:
"""Asynchronously calls a client stream and yields responses.
"""
response = await self._call_stream(StreamType.ServerStream, request)
try:
yield response
finally:
await response.aclose()

This method sends a stream request to the client and asynchronously
iterates over the responses, yielding each response one by one.
@contextlib.asynccontextmanager
async def call_client_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]:
"""Initiate a client-streaming RPC call and returns an asynchronous generator for streaming responses from the server.

Args:
request (StreamRequest[T_Request]): The stream request to be sent.
request (StreamRequest[T_Request]): The request object containing the
client-streaming data to be sent to the server.

Yields:
UnaryResponse[T_Response]: The response from the client stream.
StreamResponse[T_Response]: An asynchronous generator that yields
responses from the server.

Raises:
Any exceptions raised during the streaming call.

Notes:
- The `response.aclose()` method is called in the `finally` block to
ensure proper cleanup of the response stream.

"""
return await self._call_stream(StreamType.ClientStream, request)
response = await self._call_stream(StreamType.ClientStream, request)
try:
yield response
finally:
await response.aclose()

async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]:
"""Initiate a bidirectional streaming call.
@contextlib.asynccontextmanager
async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]:
"""Initiate a bidirectional streaming call with the server.

This method establishes a bidirectional stream between the client and the server,
allowing both to send and receive messages asynchronously.
This method sends a stream request to the server and returns an asynchronous
generator that yields stream responses from the server. The connection is
automatically closed when the generator is exhausted or an exception occurs.

Args:
request (StreamRequest[T_Request]): The request object containing the stream
of messages to be sent to the server.
request (StreamRequest[T_Request]): The stream request object containing
the data to be sent to the server.

Returns:
StreamResponse[T_Response]: An asynchronous stream response object that
allows receiving messages from the server.
Yields:
StreamResponse[T_Response]: The stream response object received from the server.

Raises:
Any exceptions raised during the streaming call will propagate to the caller.
Any exceptions raised during the streaming call.

Notes:
Ensure to consume the generator properly to avoid resource leaks, as the
connection is closed in the `finally` block.

"""
return await self._call_stream(StreamType.BiDiStream, request)
response = await self._call_stream(StreamType.BiDiStream, request)
try:
yield response
finally:
await response.aclose()
Loading
Loading