| 
34 | 34 | )  | 
35 | 35 | 
 
  | 
36 | 36 | from bson import DEFAULT_CODEC_OPTIONS  | 
37 |  | -from pymongo import _csot, helpers_shared  | 
 | 37 | +from pymongo import _csot, helpers_shared, network_layer  | 
38 | 38 | from pymongo.asynchronous.client_session import _validate_session_write_concern  | 
39 | 39 | from pymongo.asynchronous.helpers import _handle_reauth  | 
40 | 40 | from pymongo.asynchronous.network import command  | 
@@ -188,6 +188,41 @@ def __init__(  | 
188 | 188 |         self.creation_time = time.monotonic()  | 
189 | 189 |         # For gossiping $clusterTime from the connection handshake to the client.  | 
190 | 190 |         self._cluster_time = None  | 
 | 191 | +        self.pending_response = False  | 
 | 192 | +        self.pending_bytes = 0  | 
 | 193 | +        self.pending_deadline = 0.0  | 
 | 194 | + | 
 | 195 | +    def mark_pending(self, nbytes: int) -> None:  | 
 | 196 | +        """Mark this connection as having a pending response."""  | 
 | 197 | +        # TODO: add "if self.enable_pending:"  | 
 | 198 | +        self.pending_response = True  | 
 | 199 | +        self.pending_bytes = nbytes  | 
 | 200 | +        self.pending_deadline = time.monotonic() + 3  # 3 seconds timeout for pending response  | 
 | 201 | + | 
 | 202 | +    async def complete_pending(self) -> None:  | 
 | 203 | +        """Complete a pending response."""  | 
 | 204 | +        if not self.pending_response:  | 
 | 205 | +            return  | 
 | 206 | + | 
 | 207 | +        timeout: Optional[Union[float, int]]  | 
 | 208 | +        timeout = self.conn.gettimeout  | 
 | 209 | +        if _csot.get_timeout():  | 
 | 210 | +            deadline = min(_csot.get_deadline(), self.pending_deadline)  | 
 | 211 | +        elif timeout:  | 
 | 212 | +            deadline = min(time.monotonic() + timeout, self.pending_deadline)  | 
 | 213 | +        else:  | 
 | 214 | +            deadline = self.pending_deadline  | 
 | 215 | + | 
 | 216 | +        if not _IS_SYNC:  | 
 | 217 | +            # In async the reader task reads the whole message at once.  | 
 | 218 | +            # TODO: respect deadline  | 
 | 219 | +            await self.receive_message(None, True)  | 
 | 220 | +        else:  | 
 | 221 | +            # In sync we need to track the bytes left for the message.  | 
 | 222 | +            network_layer.receive_data(self.conn.get_conn, self.pending_byte, deadline)  | 
 | 223 | +        self.pending_response = False  | 
 | 224 | +        self.pending_bytes = 0  | 
 | 225 | +        self.pending_deadline = 0.0  | 
191 | 226 | 
 
  | 
192 | 227 |     def set_conn_timeout(self, timeout: Optional[float]) -> None:  | 
193 | 228 |         """Cache last timeout to avoid duplicate calls to conn.settimeout."""  | 
@@ -454,13 +489,17 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:  | 
454 | 489 |         except BaseException as error:  | 
455 | 490 |             await self._raise_connection_failure(error)  | 
456 | 491 | 
 
  | 
457 |  | -    async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:  | 
 | 492 | +    async def receive_message(  | 
 | 493 | +        self, request_id: Optional[int], enable_pending: bool = False  | 
 | 494 | +    ) -> Union[_OpReply, _OpMsg]:  | 
458 | 495 |         """Receive a raw BSON message or raise ConnectionFailure.  | 
459 | 496 | 
  | 
460 | 497 |         If any exception is raised, the socket is closed.  | 
461 | 498 |         """  | 
462 | 499 |         try:  | 
463 |  | -            return await async_receive_message(self, request_id, self.max_message_size)  | 
 | 500 | +            return await async_receive_message(  | 
 | 501 | +                self, request_id, self.max_message_size, enable_pending  | 
 | 502 | +            )  | 
464 | 503 |         # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.  | 
465 | 504 |         except BaseException as error:  | 
466 | 505 |             await self._raise_connection_failure(error)  | 
@@ -495,7 +534,9 @@ async def write_command(  | 
495 | 534 |         :param msg: bytes, the command message.  | 
496 | 535 |         """  | 
497 | 536 |         await self.send_message(msg, 0)  | 
498 |  | -        reply = await self.receive_message(request_id)  | 
 | 537 | +        reply = await self.receive_message(  | 
 | 538 | +            request_id, enable_pending=(_csot.get_timeout() is not None)  | 
 | 539 | +        )  | 
499 | 540 |         result = reply.command_response(codec_options)  | 
500 | 541 | 
 
  | 
501 | 542 |         # Raises NotPrimaryError or OperationFailure.  | 
@@ -635,7 +676,10 @@ async def _raise_connection_failure(self, error: BaseException) -> NoReturn:  | 
635 | 676 |             reason = None  | 
636 | 677 |         else:  | 
637 | 678 |             reason = ConnectionClosedReason.ERROR  | 
638 |  | -        await self.close_conn(reason)  | 
 | 679 | + | 
 | 680 | +        # Pending connections should be placed back in the pool.  | 
 | 681 | +        if not self.pending_response:  | 
 | 682 | +            await self.close_conn(reason)  | 
639 | 683 |         # SSLError from PyOpenSSL inherits directly from Exception.  | 
640 | 684 |         if isinstance(error, (IOError, OSError, SSLError)):  | 
641 | 685 |             details = _get_timeout_details(self.opts)  | 
@@ -1076,7 +1120,7 @@ async def checkout(  | 
1076 | 1120 | 
  | 
1077 | 1121 |         This method should always be used in a with-statement::  | 
1078 | 1122 | 
  | 
1079 |  | -            with pool.get_conn() as connection:  | 
 | 1123 | +            with pool.checkout() as connection:  | 
1080 | 1124 |                 connection.send_message(msg)  | 
1081 | 1125 |                 data = connection.receive_message(op_code, request_id)  | 
1082 | 1126 | 
  | 
@@ -1388,6 +1432,7 @@ async def _perished(self, conn: AsyncConnection) -> bool:  | 
1388 | 1432 |         pool, to keep performance reasonable - we can't avoid AutoReconnects  | 
1389 | 1433 |         completely anyway.  | 
1390 | 1434 |         """  | 
 | 1435 | +        await conn.complete_pending()  | 
1391 | 1436 |         idle_time_seconds = conn.idle_time_seconds()  | 
1392 | 1437 |         # If socket is idle, open a new one.  | 
1393 | 1438 |         if (  | 
 | 
0 commit comments