diff --git a/src/connect/protocol_connect/connect_handler.py b/src/connect/protocol_connect/connect_handler.py index 5e29fb8..6b9e1c9 100644 --- a/src/connect/protocol_connect/connect_handler.py +++ b/src/connect/protocol_connect/connect_handler.py @@ -207,7 +207,7 @@ async def conn( ) peer = Peer( - address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, + address=Address(host=request.client.host, port=request.client.port) if request.client else None, protocol=PROTOCOL_CONNECT, query=request.query_params, ) diff --git a/src/connect/protocol_grpc/constants.py b/src/connect/protocol_grpc/constants.py index faf48b4..25bd698 100644 --- a/src/connect/protocol_grpc/constants.py +++ b/src/connect/protocol_grpc/constants.py @@ -21,8 +21,8 @@ HEADER_X_USER_AGENT = "X-User-Agent" GRPC_ALLOWED_METHODS = [HTTPMethod.POST] - -DEFAULT_GRPC_USER_AGENT = f"connect-python/{__version__} (Python/{__version__})" +_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" +DEFAULT_GRPC_USER_AGENT = f"connect-python/{__version__} (Python/{_python_version})" RE_TIMEOUT = re.compile(r"^(\d{1,8})([HMSmun])$") @@ -35,4 +35,6 @@ "H": 3600.0, } +GRPC_TIMEOUT_MAX_VALUE = 10**8 +GRPC_TIMEOUT_MAX_DURATION = 99_999_999 MAX_HOURS = sys.maxsize // (60 * 60 * 1_000_000_000) diff --git a/src/connect/protocol_grpc/error_trailer.py b/src/connect/protocol_grpc/error_trailer.py index 2319831..c6ab0b8 100644 --- a/src/connect/protocol_grpc/error_trailer.py +++ b/src/connect/protocol_grpc/error_trailer.py @@ -37,7 +37,7 @@ def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: trailer[GRPC_HEADER_STATUS] = "0" return - if not ConnectError.wire_error: + if not error.wire_error: trailer.update(exclude_protocol_headers(error.metadata)) status = status_pb2.Status( @@ -152,9 +152,6 @@ def decode_binary_header(data: str) -> bytes: Returns: bytes: The decoded binary data. - Raises: - binascii.Error: If the input is not correctly base64-encoded. - """ if len(data) % 4: data += "=" * (-len(data) % 4) diff --git a/src/connect/protocol_grpc/grpc_client.py b/src/connect/protocol_grpc/grpc_client.py index 1819b60..829b246 100644 --- a/src/connect/protocol_grpc/grpc_client.py +++ b/src/connect/protocol_grpc/grpc_client.py @@ -35,6 +35,7 @@ GRPC_HEADER_ACCEPT_COMPRESSION, GRPC_HEADER_COMPRESSION, GRPC_HEADER_TIMEOUT, + GRPC_TIMEOUT_MAX_VALUE, HEADER_X_USER_AGENT, UNIT_TO_SECONDS, ) @@ -254,7 +255,7 @@ def spec(self) -> Spec: @property def peer(self) -> Peer: """Return the peer information.""" - raise NotImplementedError() + return self._peer async def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: """Receives a message and processes it.""" @@ -284,7 +285,8 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None) -> Asyn if self.unmarshaler.bytes_read == 0 and len(self.response_trailers) == 0: self.response_trailers.update(self._response_headers) - del self._response_headers[HEADER_CONTENT_TYPE] + if HEADER_CONTENT_TYPE in self._response_headers: + del self._response_headers[HEADER_CONTENT_TYPE] server_error = grpc_error_from_trailer(self.response_trailers) if server_error: @@ -294,7 +296,7 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None) -> Asyn server_error = grpc_error_from_trailer(self.response_trailers) if server_error: server_error.metadata = self.response_headers.copy() - server_error.metadata.update(self.response_trailers) + server_error.metadata.update(self.response_trailers.copy()) raise server_error def _receive_trailers(self, response: httpcore.Response) -> None: @@ -368,20 +370,25 @@ async def send( request_task = asyncio.create_task(self.pool.handle_async_request(request=request)) abort_task = asyncio.create_task(abort_event.wait()) - done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + try: + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) - if abort_task in done: - request_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await request_task + if abort_task in done: + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task - raise ConnectError("request aborted", Code.CANCELED) + raise ConnectError("request aborted", Code.CANCELED) - abort_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await abort_task + abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task - response = await request_task + response = await request_task + finally: + for task in [request_task, abort_task]: + if not task.done(): + task.cancel() for hook in self._event_hooks["response"]: hook(response) @@ -466,7 +473,7 @@ def grpc_encode_timeout(timeout: float) -> str: if timeout <= 0: return "0n" - grpc_timeout_max_value = 10**8 + grpc_timeout_max_value = GRPC_TIMEOUT_MAX_VALUE _units = dict(sorted(UNIT_TO_SECONDS.items(), key=lambda item: item[1])) for unit, size in _units.items(): diff --git a/src/connect/protocol_grpc/grpc_handler.py b/src/connect/protocol_grpc/grpc_handler.py index fdd34b1..5d25b6b 100644 --- a/src/connect/protocol_grpc/grpc_handler.py +++ b/src/connect/protocol_grpc/grpc_handler.py @@ -26,6 +26,7 @@ GRPC_HEADER_ACCEPT_COMPRESSION, GRPC_HEADER_COMPRESSION, GRPC_HEADER_TIMEOUT, + GRPC_TIMEOUT_MAX_DURATION, MAX_HOURS, RE_TIMEOUT, UNIT_TO_SECONDS, @@ -145,7 +146,7 @@ async def conn( protocol_name = PROTOCOL_GRPC if not self.web else PROTOCOL_GRPC + "-web" peer = Peer( - address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, + address=Address(host=request.client.host, port=request.client.port) if request.client else None, protocol=protocol_name, query=request.query_params, ) @@ -267,7 +268,7 @@ def parse_timeout(self) -> float | None: num_str, unit = m.groups() num = int(num_str) - if num > 99_999_999: + if num > GRPC_TIMEOUT_MAX_DURATION: raise ConnectError(f"protocol error: timeout {timeout!r} is too long") if unit == "H" and num > MAX_HOURS: