Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cmd/protoc-gen-connect-python/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
p.P(`from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse`)
p.P(`from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler`)
p.P(`from connect.options import ClientOptions, ConnectOptions`)
p.P(`from connect.session import AsyncClientSession`)
p.P(`from connect.connection_pool import AsyncConnectionPool`)
p.P(`from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor`)
p.P()

Expand Down Expand Up @@ -238,13 +238,13 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
p.P()
p.P()
p.P(`class `, upperSvcName, `Client:`)
p.P(` def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None:`)
p.P(` def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOptions | None = None) -> None:`)
p.P(` base_url = base_url.removesuffix("/")`)
p.P()
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
p.P(` `, `self.`, meth.Method, ` = `, `Client[`, svc.input.method, `, `, svc.output.method, `](`)
p.P(` `, `session, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`)
p.P(` `, `pool, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`)
switch meth.RPCType {
case Unary:
p.P(` `, `).call_unary`)
Expand Down
8 changes: 4 additions & 4 deletions conformance/client_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from typing import Any

from connect.connect import StreamRequest, UnaryRequest
from connect.connection_pool import AsyncConnectionPool
from connect.error import ConnectError
from connect.headers import Headers
from connect.options import ClientOptions
from connect.session import AsyncClientSession
from google.protobuf import any_pb2
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer

Expand Down Expand Up @@ -182,7 +182,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c
- Captures and returns errors in the response if exceptions occur.

Note:
- This function uses an asynchronous HTTP client session (`AsyncClientSession`)
- This function uses an asynchronous HTTP client connection pool (`AsyncConnectionPool`)
for making requests.
- Compression (e.g., gzip) is applied if specified in the request.
- Headers and trailers are converted to protobuf-compatible formats.
Expand Down Expand Up @@ -215,7 +215,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c

url = f"{proto}://{msg.host}:{msg.port}"

async with AsyncClientSession(http1=http1, http2=http2, ssl_context=ssl_context) as session:
async with AsyncConnectionPool(http1=http1, http2=http2, ssl_context=ssl_context) as pool:
payloads = []
try:
options = ClientOptions()
Expand All @@ -231,7 +231,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c
if msg.codec == config_pb2.CODEC_JSON:
options.use_binary_format = False

client = service_connect.ConformanceServiceClient(base_url=url, session=session, options=options)
client = service_connect.ConformanceServiceClient(base_url=url, pool=pool, options=options)
if msg.stream_type == config_pb2.STREAM_TYPE_UNARY:
if msg.request_delay_ms > 0:
await asyncio.sleep(msg.request_delay_ms / 1000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import connect.connect
from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler, BidiStreamHandler
from connect.options import ClientOptions, ConnectOptions
from connect.session import AsyncClientSession
from connect.connection_pool import AsyncConnectionPool
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor
from connect.idempotency_level import IdempotencyLevel

Expand Down Expand Up @@ -40,26 +40,26 @@ class ConformanceServiceProcedures(Enum):


class ConformanceServiceClient:
def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None:
def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOptions | None = None) -> None:
base_url = base_url.removesuffix("/")

self.Unary = Client[UnaryRequest, UnaryResponse](
session, base_url + ConformanceServiceProcedures.Unary.value, UnaryRequest, UnaryResponse, options
pool, base_url + ConformanceServiceProcedures.Unary.value, UnaryRequest, UnaryResponse, options
).call_unary
self.ServerStream = Client[ServerStreamRequest, ServerStreamResponse](
session, base_url + ConformanceServiceProcedures.ServerStream.value, ServerStreamRequest, ServerStreamResponse, options
pool, base_url + ConformanceServiceProcedures.ServerStream.value, ServerStreamRequest, ServerStreamResponse, options
).call_server_stream
self.ClientStream = Client[ClientStreamRequest, ClientStreamResponse](
session, base_url + ConformanceServiceProcedures.ClientStream.value, ClientStreamRequest, ClientStreamResponse, options
pool, base_url + ConformanceServiceProcedures.ClientStream.value, ClientStreamRequest, ClientStreamResponse, options
).call_client_stream
self.BidiStream = Client[BidiStreamRequest, BidiStreamResponse](
session, base_url + ConformanceServiceProcedures.BidiStream.value, BidiStreamRequest, BidiStreamResponse, options
pool, base_url + ConformanceServiceProcedures.BidiStream.value, BidiStreamRequest, BidiStreamResponse, options
).call_bidi_stream
self.Unimplemented = Client[UnimplementedRequest, UnimplementedResponse](
session, base_url + ConformanceServiceProcedures.Unimplemented.value, UnimplementedRequest, UnimplementedResponse, options
pool, base_url + ConformanceServiceProcedures.Unimplemented.value, UnimplementedRequest, UnimplementedResponse, options
).call_unary
self.IdempotentUnary = Client[IdempotentUnaryRequest, IdempotentUnaryResponse](
session, base_url + ConformanceServiceProcedures.IdempotentUnary.value, IdempotentUnaryRequest, IdempotentUnaryResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options),
pool, base_url + ConformanceServiceProcedures.IdempotentUnary.value, IdempotentUnaryRequest, IdempotentUnaryResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options),
).call_unary


Expand Down
6 changes: 3 additions & 3 deletions examples/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging

from connect.connect import UnaryRequest
from connect.session import AsyncClientSession
from connect.connection_pool import AsyncConnectionPool

from proto.connectrpc.eliza.v1.eliza_pb2 import SayRequest
from proto.connectrpc.eliza.v1.v1connect.eliza_connect_pb2 import ElizaServiceClient
Expand All @@ -15,9 +15,9 @@

async def main() -> None:
"""Interact with the ElizaServiceClient asynchronously."""
async with AsyncClientSession() as session:
async with AsyncConnectionPool() as pool:
client = ElizaServiceClient(
session=session,
pool=pool,
base_url="http://localhost:8080/",
)
response = await client.Say(UnaryRequest(SayRequest(sentence="I feel happy.")))
Expand Down
14 changes: 7 additions & 7 deletions examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

from connect.client import Client
from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse
from connect.connection_pool import AsyncConnectionPool
from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler
from connect.options import ClientOptions, ConnectOptions
from connect.session import AsyncClientSession
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor

from .. import eliza_pb2
from ..eliza_pb2 import SayRequest, SayResponse, ConverseRequest, ConverseResponse, IntroduceRequest, IntroduceResponse
from ..eliza_pb2 import ConverseRequest, ConverseResponse, IntroduceRequest, IntroduceResponse, SayRequest, SayResponse


class ElizaServiceProcedures(Enum):
Expand All @@ -35,20 +35,20 @@ class ElizaServiceProcedures(Enum):


class ElizaServiceClient:
def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None:
def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOptions | None = None) -> None:
base_url = base_url.removesuffix("/")

self.Say = Client[SayRequest, SayResponse](
session, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options
pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options
).call_unary
self.Converse = Client[ConverseRequest, ConverseResponse](
session, base_url + ElizaServiceProcedures.Converse.value, ConverseRequest, ConverseResponse, options
pool, base_url + ElizaServiceProcedures.Converse.value, ConverseRequest, ConverseResponse, options
).call_server_stream
self.IntroduceServer = Client[IntroduceRequest, IntroduceResponse](
session, base_url + ElizaServiceProcedures.IntroduceServer.value, IntroduceRequest, IntroduceResponse, options
pool, base_url + ElizaServiceProcedures.IntroduceServer.value, IntroduceRequest, IntroduceResponse, options
).call_server_stream
self.IntroduceClient = Client[IntroduceRequest, IntroduceResponse](
session, base_url + ElizaServiceProcedures.IntroduceClient.value, IntroduceRequest, IntroduceResponse, options
pool, base_url + ElizaServiceProcedures.IntroduceClient.value, IntroduceRequest, IntroduceResponse, options
).call_client_stream


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

from connect.client import Client
from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse
from connect.connection_pool import AsyncConnectionPool
from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler
from connect.options import ClientOptions, ConnectOptions
from connect.session import AsyncClientSession
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor

from .. import eliza_pb2
Expand All @@ -38,21 +38,21 @@ class ElizaServiceProcedures(Enum):


class ElizaServiceClient:
def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None:
def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOptions | None = None) -> None:
base_url = base_url.removesuffix("/")

self.Say = Client[SayRequest, SayResponse](
session, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options
pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options
).call_unary
self.IntroduceServer = Client[IntroduceRequest, IntroduceResponse](
session,
pool,
base_url + ElizaServiceProcedures.IntroduceServer.value,
IntroduceRequest,
IntroduceResponse,
options,
).call_server_stream
self.IntroduceClient = Client[IntroduceRequest, IntroduceResponse](
session,
pool,
base_url + ElizaServiceProcedures.IntroduceClient.value,
IntroduceRequest,
IntroduceResponse,
Expand Down
8 changes: 4 additions & 4 deletions src/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
recieve_stream_response,
recieve_unary_response,
)
from connect.connection_pool import AsyncConnectionPool
from connect.error import ConnectError
from connect.idempotency_level import IdempotencyLevel
from connect.interceptor import apply_interceptors
from connect.options import ClientOptions
from connect.protocol import Protocol, ProtocolClient, ProtocolClientParams
from connect.protocol_connect.connect_protocol import ProtocolConnect
from connect.protocol_grpc.grpc_protocol import ProtocolGRPC
from connect.session import AsyncClientSession
from connect.utils import aiterate


Expand Down Expand Up @@ -186,7 +186,7 @@ class Client[T_Request, T_Response]:

def __init__(
self,
session: AsyncClientSession,
pool: AsyncConnectionPool,
url: str,
input: type[T_Request],
output: type[T_Response],
Expand All @@ -195,7 +195,7 @@ def __init__(
"""Initialize the client with the given URL, request and response types, and optional client options.

Args:
session (AsyncClientSession): The client session to use for the connection.
pool (AsyncConnectionPool): The connection pool to use for making requests.
url (str): The URL of the server to connect to.
input (type[T_Request]): The type of the request object.
output (type[T_Response]): The type of the response object.
Expand All @@ -212,7 +212,7 @@ def __init__(

protocol_client = config.protocol.client(
ProtocolClientParams(
session=session,
pool=pool,
codec=config.codec,
url=config.url,
compression_name=config.request_compression_name,
Expand Down
3 changes: 3 additions & 0 deletions src/connect/connection_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Provides connection pool functionality using httpcore's AsyncConnectionPool."""

from httpcore import AsyncConnectionPool as AsyncConnectionPool
4 changes: 2 additions & 2 deletions src/connect/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
StreamingHandlerConn,
StreamType,
)
from connect.connection_pool import AsyncConnectionPool
from connect.error import ConnectError
from connect.headers import Headers
from connect.idempotency_level import IdempotencyLevel
from connect.request import Request
from connect.response_writer import ServerResponseWriter
from connect.session import AsyncClientSession

PROTOCOL_CONNECT = "connect"
PROTOCOL_GRPC = "grpc"
Expand Down Expand Up @@ -70,7 +70,7 @@ class ProtocolClientParams(BaseModel):
arbitrary_types_allowed=True,
)

session: AsyncClientSession
pool: AsyncConnectionPool
codec: Codec
url: URL
compression_name: str | None = Field(default=None)
Expand Down
28 changes: 14 additions & 14 deletions src/connect/protocol_connect/connect_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
StreamType,
ensure_single,
)
from connect.connection_pool import AsyncConnectionPool
from connect.content_stream import BoundAsyncStream
from connect.error import ConnectError
from connect.headers import Headers, include_request_headers
Expand Down Expand Up @@ -59,7 +60,6 @@
from connect.protocol_connect.error_json import error_from_json
from connect.protocol_connect.marshaler import ConnectStreamingMarshaler, ConnectUnaryRequestMarshaler
from connect.protocol_connect.unmarshaler import ConnectStreamingUnmarshaler, ConnectUnaryUnmarshaler
from connect.session import AsyncClientSession
from connect.utils import (
map_httpcore_exceptions,
)
Expand Down Expand Up @@ -147,7 +147,7 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn:
conn: StreamingClientConn
if spec.stream_type == StreamType.Unary:
conn = ConnectUnaryClientConn(
session=self.params.session,
pool=self.params.pool,
spec=spec,
peer=self.peer,
url=self.params.url,
Expand All @@ -172,7 +172,7 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn:
conn.marshaler.stable_codec = self.params.codec
else:
conn = ConnectStreamingClientConn(
session=self.params.session,
pool=self.params.pool,
spec=spec,
peer=self.peer,
url=self.params.url,
Expand Down Expand Up @@ -212,7 +212,7 @@ class ConnectUnaryClientConn(StreamingClientConn):

"""

session: AsyncClientSession
pool: AsyncConnectionPool
_spec: Spec
_peer: Peer
url: URL
Expand All @@ -226,7 +226,7 @@ class ConnectUnaryClientConn(StreamingClientConn):

def __init__(
self,
session: AsyncClientSession,
pool: AsyncConnectionPool,
spec: Spec,
peer: Peer,
url: URL,
Expand All @@ -239,7 +239,7 @@ def __init__(
"""Initialize the ConnectProtocol instance.

Args:
session (AsyncClientSession): The session for the connection.
pool (AsyncConnectionPool): The connection pool for the client.
spec (Spec): The specification for the connection.
peer (Peer): The peer information.
url (URL): The URL for the connection.
Expand All @@ -255,7 +255,7 @@ def __init__(
"""
event_hooks = {} if event_hooks is None else event_hooks

self.session = session
self.pool = pool
self._spec = spec
self._peer = peer
self.url = url
Expand Down Expand Up @@ -412,9 +412,9 @@ async def send(

with map_httpcore_exceptions():
if not abort_event:
response = await self.session.pool.handle_async_request(request=request)
response = await self.pool.handle_async_request(request=request)
else:
request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request))
request_task = asyncio.create_task(self.pool.handle_async_request(request=request))
abort_task = asyncio.create_task(abort_event.wait())

done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED)
Expand Down Expand Up @@ -563,7 +563,7 @@ class ConnectStreamingClientConn(StreamingClientConn):

def __init__(
self,
session: AsyncClientSession,
pool: AsyncConnectionPool,
spec: Spec,
peer: Peer,
url: URL,
Expand All @@ -577,7 +577,7 @@ def __init__(
"""Initialize a new instance of the class.

Args:
session (AsyncClientSession): The session object for the connection.
pool (AsyncConnectionPool): The connection pool for the client.
spec (Spec): The specification object.
peer (Peer): The peer object.
url (URL): The URL for the connection.
Expand All @@ -594,7 +594,7 @@ def __init__(
"""
event_hooks = {} if event_hooks is None else event_hooks

self.session = session
self.pool = pool
self._spec = spec
self._peer = peer
self.url = url
Expand Down Expand Up @@ -771,9 +771,9 @@ async def send(

with map_httpcore_exceptions():
if not abort_event:
response = await self.session.pool.handle_async_request(request)
response = await self.pool.handle_async_request(request)
else:
request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request))
request_task = asyncio.create_task(self.pool.handle_async_request(request=request))
abort_task = asyncio.create_task(abort_event.wait())

done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED)
Expand Down
Loading
Loading