diff --git a/pymongo/asynchronous/aggregation.py b/pymongo/asynchronous/aggregation.py index 059d698772..6ca60ad9c3 100644 --- a/pymongo/asynchronous/aggregation.py +++ b/pymongo/asynchronous/aggregation.py @@ -50,7 +50,6 @@ def __init__( cursor_class: type[AsyncCommandCursor[Any]], pipeline: _Pipeline, options: MutableMapping[str, Any], - explicit_session: bool, let: Optional[Mapping[str, Any]] = None, user_fields: Optional[MutableMapping[str, Any]] = None, result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None, @@ -92,7 +91,6 @@ def __init__( self._options["cursor"]["batchSize"] = self._batch_size self._cursor_class = cursor_class - self._explicit_session = explicit_session self._user_fields = user_fields self._result_processor = result_processor @@ -197,7 +195,6 @@ async def get_cursor( batch_size=self._batch_size or 0, max_await_time_ms=self._max_await_time_ms, session=session, - explicit_session=self._explicit_session, comment=self._options.get("comment"), ) await cmd_cursor._maybe_pin_connection(conn) diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index 3940111df2..b2b78b0660 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -236,7 +236,7 @@ def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> N ) async def _run_aggregation_cmd( - self, session: Optional[AsyncClientSession], explicit_session: bool + self, session: Optional[AsyncClientSession] ) -> AsyncCommandCursor: # type: ignore[type-arg] """Run the full aggregation pipeline for this AsyncChangeStream and return the corresponding AsyncCommandCursor. 
@@ -246,7 +246,6 @@ async def _run_aggregation_cmd( AsyncCommandCursor, self._aggregation_pipeline(), self._command_options(), - explicit_session, result_processor=self._process_result, comment=self._comment, ) @@ -258,10 +257,8 @@ async def _run_aggregation_cmd( ) async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg] - async with self._client._tmp_session(self._session, close=False) as s: - return await self._run_aggregation_cmd( - session=s, explicit_session=self._session is not None - ) + async with self._client._tmp_session(self._session) as s: + return await self._run_aggregation_cmd(session=s) async def _resume(self) -> None: """Reestablish this change stream after a resumable error.""" diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 45812b3400..151942c8a8 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -440,6 +440,8 @@ async def _process_results_cursor( ) -> None: """Internal helper for processing the server reply command cursor.""" if result.get("cursor"): + if session: + session._leave_alive = True coll = AsyncCollection( database=AsyncDatabase(self.client, "admin"), name="$cmd.bulkWrite", @@ -449,7 +451,6 @@ async def _process_results_cursor( result["cursor"], conn.address, session=session, - explicit_session=session is not None, comment=self.comment, ) await cmd_cursor._maybe_pin_connection(conn) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index be02295cea..8674e98447 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -513,6 +513,10 @@ def __init__( # Is this an implicitly created session? self._implicit = implicit self._transaction = _Transaction(None, client) + # Is this session attached to a cursor? + self._attached_to_cursor = False + # Should we leave the session alive when the cursor is closed? 
+ self._leave_alive = False async def end_session(self) -> None: """Finish this session. If a transaction has started, abort it. @@ -535,7 +539,7 @@ async def _end_session(self, lock: bool) -> None: def _end_implicit_session(self) -> None: # Implicit sessions can't be part of transactions or pinned connections - if self._server_session is not None: + if not self._leave_alive and self._server_session is not None: self._client._return_server_session(self._server_session) self._server_session = None diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 064231ccfc..6af1f4f782 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -2549,7 +2549,6 @@ async def _list_indexes( self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY), ) read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY - explicit_session = session is not None async def _cmd( session: Optional[AsyncClientSession], @@ -2576,13 +2575,12 @@ async def _cmd( cursor, conn.address, session=session, - explicit_session=explicit_session, comment=cmd.get("comment"), ) await cmd_cursor._maybe_pin_connection(conn) return cmd_cursor - async with self._database.client._tmp_session(session, False) as s: + async with self._database.client._tmp_session(session) as s: return await self._database.client._retryable_read( _cmd, read_pref, s, operation=_Op.LIST_INDEXES ) @@ -2678,7 +2676,6 @@ async def list_search_indexes( AsyncCommandCursor, pipeline, kwargs, - explicit_session=session is not None, comment=comment, user_fields={"cursor": {"firstBatch": 1}}, ) @@ -2900,7 +2897,6 @@ async def _aggregate( pipeline: _Pipeline, cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg] session: Optional[AsyncClientSession], - explicit_session: bool, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -2912,7 +2908,6 @@ async def _aggregate( cursor_class, 
pipeline, kwargs, - explicit_session, let, user_fields={"cursor": {"firstBatch": 1}}, ) @@ -3018,13 +3013,12 @@ async def aggregate( .. _aggregate command: https://mongodb.com/docs/manual/reference/command/aggregate """ - async with self._database.client._tmp_session(session, close=False) as s: + async with self._database.client._tmp_session(session) as s: return await self._aggregate( _CollectionAggregationCommand, pipeline, AsyncCommandCursor, session=s, - explicit_session=session is not None, let=let, comment=comment, **kwargs, @@ -3065,7 +3059,7 @@ async def aggregate_raw_batches( raise InvalidOperation("aggregate_raw_batches does not support auto encryption") if comment is not None: kwargs["comment"] = comment - async with self._database.client._tmp_session(session, close=False) as s: + async with self._database.client._tmp_session(session) as s: return cast( AsyncRawBatchCursor[_DocumentType], await self._aggregate( @@ -3073,7 +3067,6 @@ async def aggregate_raw_batches( pipeline, AsyncRawBatchCommandCursor, session=s, - explicit_session=session is not None, **kwargs, ), ) diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index db7c2b6638..e18b3a330e 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -64,7 +64,6 @@ def __init__( batch_size: int = 0, max_await_time_ms: Optional[int] = None, session: Optional[AsyncClientSession] = None, - explicit_session: bool = False, comment: Any = None, ) -> None: """Create a new command cursor.""" @@ -80,7 +79,8 @@ def __init__( self._max_await_time_ms = max_await_time_ms self._timeout = self._collection.database.client.options.timeout self._session = session - self._explicit_session = explicit_session + if self._session is not None: + self._session._attached_to_cursor = True self._killed = self._id == 0 self._comment = comment if self._killed: @@ -197,7 +197,7 @@ def session(self) -> Optional[AsyncClientSession]: .. 
versionadded:: 3.6 """ - if self._explicit_session: + if self._session and not self._session._implicit: return self._session return None @@ -218,9 +218,10 @@ def _die_no_lock(self) -> None: """Closes this cursor without acquiring a lock.""" cursor_id, address = self._prepare_to_die() self._collection.database.client._cleanup_cursor_no_lock( - cursor_id, address, self._sock_mgr, self._session, self._explicit_session + cursor_id, address, self._sock_mgr, self._session ) - if not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session = None self._sock_mgr = None @@ -232,14 +233,15 @@ async def _die_lock(self) -> None: address, self._sock_mgr, self._session, - self._explicit_session, ) - if not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session = None self._sock_mgr = None def _end_session(self) -> None: - if self._session and not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session._end_implicit_session() self._session = None @@ -430,7 +432,6 @@ def __init__( batch_size: int = 0, max_await_time_ms: Optional[int] = None, session: Optional[AsyncClientSession] = None, - explicit_session: bool = False, comment: Any = None, ) -> None: """Create a new cursor / iterator over raw batches of BSON data. 
@@ -449,7 +450,6 @@ def __init__( batch_size, max_await_time_ms, session, - explicit_session, comment, ) diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index d9fdd576f4..df060a4fa9 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -138,10 +138,9 @@ def __init__( if session: self._session = session - self._explicit_session = True + self._session._attached_to_cursor = True else: self._session = None - self._explicit_session = False spec: Mapping[str, Any] = filter or {} validate_is_mapping("filter", spec) @@ -150,7 +149,7 @@ def __init__( if not isinstance(limit, int): raise TypeError(f"limit must be an instance of int, not {type(limit)}") validate_boolean("no_cursor_timeout", no_cursor_timeout) - if no_cursor_timeout and not self._explicit_session: + if no_cursor_timeout and self._session and self._session._implicit: warnings.warn( "use an explicit session with no_cursor_timeout=True " "otherwise the cursor may still timeout after " @@ -283,7 +282,7 @@ def clone(self) -> AsyncCursor[_DocumentType]: def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg] """Internal clone helper.""" if not base: - if self._explicit_session: + if self._session and not self._session._implicit: base = self._clone_base(self._session) else: base = self._clone_base(None) @@ -945,7 +944,7 @@ def session(self) -> Optional[AsyncClientSession]: .. 
versionadded:: 3.6 """ - if self._explicit_session: + if self._session and not self._session._implicit: return self._session return None @@ -1034,9 +1033,10 @@ def _die_no_lock(self) -> None: cursor_id, address = self._prepare_to_die(already_killed) self._collection.database.client._cleanup_cursor_no_lock( - cursor_id, address, self._sock_mgr, self._session, self._explicit_session + cursor_id, address, self._sock_mgr, self._session ) - if not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session = None self._sock_mgr = None @@ -1054,9 +1054,9 @@ async def _die_lock(self) -> None: address, self._sock_mgr, self._session, - self._explicit_session, ) - if not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session = None self._sock_mgr = None diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index f70c2b403f..8e0afc9dc9 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -611,6 +611,8 @@ async def create_collection( common.validate_is_mapping("clusteredIndex", clustered_index) async with self._client._tmp_session(session) as s: + if s and not s.in_transaction: + s._leave_alive = True # Skip this check in a transaction where listCollections is not # supported. if ( @@ -619,6 +621,8 @@ async def create_collection( and name in await self._list_collection_names(filter={"name": name}, session=s) ): raise CollectionInvalid("collection %s already exists" % name) + if s: + s._leave_alive = False coll = AsyncCollection( self, name, @@ -699,13 +703,12 @@ async def aggregate( .. 
_aggregate command: https://mongodb.com/docs/manual/reference/command/aggregate """ - async with self.client._tmp_session(session, close=False) as s: + async with self.client._tmp_session(session) as s: cmd = _DatabaseAggregationCommand( self, AsyncCommandCursor, pipeline, kwargs, - session is not None, user_fields={"cursor": {"firstBatch": 1}}, ) return await self.client._retryable_read( @@ -1011,7 +1014,7 @@ async def cursor_command( else: command_name = next(iter(command)) - async with self._client._tmp_session(session, close=False) as tmp_session: + async with self._client._tmp_session(session) as tmp_session: opts = codec_options or DEFAULT_CODEC_OPTIONS if read_preference is None: @@ -1043,7 +1046,6 @@ async def cursor_command( conn.address, max_await_time_ms=max_await_time_ms, session=tmp_session, - explicit_session=session is not None, comment=comment, ) await cmd_cursor._maybe_pin_connection(conn) @@ -1089,7 +1091,7 @@ async def _list_collections( ) cmd = {"listCollections": 1, "cursor": {}} cmd.update(kwargs) - async with self._client._tmp_session(session, close=False) as tmp_session: + async with self._client._tmp_session(session) as tmp_session: cursor = ( await self._command(conn, cmd, read_preference=read_preference, session=tmp_session) )["cursor"] @@ -1098,7 +1100,6 @@ async def _list_collections( cursor, conn.address, session=tmp_session, - explicit_session=session is not None, comment=cmd.get("comment"), ) await cmd_cursor._maybe_pin_connection(conn) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index b616647791..d9bf808d55 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2048,17 +2048,18 @@ async def _retryable_read( retryable = bool( retryable and self.options.retry_reads and not (session and session.in_transaction) ) - return await self._retry_internal( - func, - session, - None, - operation, - is_read=True, - address=address, - read_pref=read_pref, - 
retryable=retryable, - operation_id=operation_id, - ) + async with self._tmp_session(session) as s: + return await self._retry_internal( + func, + s, + None, + operation, + is_read=True, + address=address, + read_pref=read_pref, + retryable=retryable, + operation_id=operation_id, + ) async def _retryable_write( self, @@ -2091,7 +2092,6 @@ def _cleanup_cursor_no_lock( address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, session: Optional[AsyncClientSession], - explicit_session: bool, ) -> None: """Cleanup a cursor from __del__ without locking. @@ -2106,7 +2106,7 @@ def _cleanup_cursor_no_lock( # The cursor will be closed later in a different session. if cursor_id or conn_mgr: self._close_cursor_soon(cursor_id, address, conn_mgr) - if session and not explicit_session: + if session and session._implicit and not session._leave_alive: session._end_implicit_session() async def _cleanup_cursor_lock( @@ -2115,7 +2115,6 @@ async def _cleanup_cursor_lock( address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, session: Optional[AsyncClientSession], - explicit_session: bool, ) -> None: """Cleanup a cursor from cursor.close() using a lock. @@ -2127,7 +2126,6 @@ async def _cleanup_cursor_lock( :param address: The _CursorAddress. :param conn_mgr: The _ConnectionManager for the pinned connection or None. :param session: The cursor's session. - :param explicit_session: True if the session was passed explicitly. 
""" if cursor_id: if conn_mgr and conn_mgr.more_to_come: @@ -2140,7 +2138,7 @@ async def _cleanup_cursor_lock( await self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr) if conn_mgr: await conn_mgr.close() - if session and not explicit_session: + if session and session._implicit and not session._leave_alive: session._end_implicit_session() async def _close_cursor_now( @@ -2221,7 +2219,7 @@ async def _process_kill_cursors(self) -> None: for address, cursor_id, conn_mgr in pinned_cursors: try: - await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) + await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None) except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2266,7 +2264,7 @@ def _return_server_session( @contextlib.asynccontextmanager async def _tmp_session( - self, session: Optional[client_session.AsyncClientSession], close: bool = True + self, session: Optional[client_session.AsyncClientSession] ) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]: """If provided session is None, lend a temporary session.""" if session is not None: @@ -2291,7 +2289,7 @@ async def _tmp_session( raise finally: # Call end_session when we exit this scope. 
- if close: + if not s._attached_to_cursor: await s.end_session() else: yield None diff --git a/pymongo/synchronous/aggregation.py b/pymongo/synchronous/aggregation.py index 9845f28b08..486768ab7d 100644 --- a/pymongo/synchronous/aggregation.py +++ b/pymongo/synchronous/aggregation.py @@ -50,7 +50,6 @@ def __init__( cursor_class: type[CommandCursor[Any]], pipeline: _Pipeline, options: MutableMapping[str, Any], - explicit_session: bool, let: Optional[Mapping[str, Any]] = None, user_fields: Optional[MutableMapping[str, Any]] = None, result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None, @@ -92,7 +91,6 @@ def __init__( self._options["cursor"]["batchSize"] = self._batch_size self._cursor_class = cursor_class - self._explicit_session = explicit_session self._user_fields = user_fields self._result_processor = result_processor @@ -197,7 +195,6 @@ def get_cursor( batch_size=self._batch_size or 0, max_await_time_ms=self._max_await_time_ms, session=session, - explicit_session=self._explicit_session, comment=self._options.get("comment"), ) cmd_cursor._maybe_pin_connection(conn) diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index f5f6352186..7e34d7b848 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -235,9 +235,7 @@ def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None: f"response : {result!r}" ) - def _run_aggregation_cmd( - self, session: Optional[ClientSession], explicit_session: bool - ) -> CommandCursor: # type: ignore[type-arg] + def _run_aggregation_cmd(self, session: Optional[ClientSession]) -> CommandCursor: # type: ignore[type-arg] """Run the full aggregation pipeline for this ChangeStream and return the corresponding CommandCursor. 
""" @@ -246,7 +244,6 @@ def _run_aggregation_cmd( CommandCursor, self._aggregation_pipeline(), self._command_options(), - explicit_session, result_processor=self._process_result, comment=self._comment, ) @@ -258,8 +255,8 @@ def _run_aggregation_cmd( ) def _create_cursor(self) -> CommandCursor: # type: ignore[type-arg] - with self._client._tmp_session(self._session, close=False) as s: - return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None) + with self._client._tmp_session(self._session) as s: + return self._run_aggregation_cmd(session=s) def _resume(self) -> None: """Reestablish this change stream after a resumable error.""" diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 1076ceba99..a606d028e1 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -438,6 +438,8 @@ def _process_results_cursor( ) -> None: """Internal helper for processing the server reply command cursor.""" if result.get("cursor"): + if session: + session._leave_alive = True coll = Collection( database=Database(self.client, "admin"), name="$cmd.bulkWrite", @@ -447,7 +449,6 @@ def _process_results_cursor( result["cursor"], conn.address, session=session, - explicit_session=session is not None, comment=self.comment, ) cmd_cursor._maybe_pin_connection(conn) diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 72a5b8e885..9b547dc946 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -512,6 +512,10 @@ def __init__( # Is this an implicitly created session? self._implicit = implicit self._transaction = _Transaction(None, client) + # Is this session attached to a cursor? + self._attached_to_cursor = False + # Should we leave the session alive when the cursor is closed? + self._leave_alive = False def end_session(self) -> None: """Finish this session. If a transaction has started, abort it. 
@@ -534,7 +538,7 @@ def _end_session(self, lock: bool) -> None: def _end_implicit_session(self) -> None: # Implicit sessions can't be part of transactions or pinned connections - if self._server_session is not None: + if not self._leave_alive and self._server_session is not None: self._client._return_server_session(self._server_session) self._server_session = None diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index e5cc816cd3..b68e4befed 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -2546,7 +2546,6 @@ def _list_indexes( self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY), ) read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY - explicit_session = session is not None def _cmd( session: Optional[ClientSession], @@ -2573,13 +2572,12 @@ def _cmd( cursor, conn.address, session=session, - explicit_session=explicit_session, comment=cmd.get("comment"), ) cmd_cursor._maybe_pin_connection(conn) return cmd_cursor - with self._database.client._tmp_session(session, False) as s: + with self._database.client._tmp_session(session) as s: return self._database.client._retryable_read( _cmd, read_pref, s, operation=_Op.LIST_INDEXES ) @@ -2675,7 +2673,6 @@ def list_search_indexes( CommandCursor, pipeline, kwargs, - explicit_session=session is not None, comment=comment, user_fields={"cursor": {"firstBatch": 1}}, ) @@ -2893,7 +2890,6 @@ def _aggregate( pipeline: _Pipeline, cursor_class: Type[CommandCursor], # type: ignore[type-arg] session: Optional[ClientSession], - explicit_session: bool, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -2905,7 +2901,6 @@ def _aggregate( cursor_class, pipeline, kwargs, - explicit_session, let, user_fields={"cursor": {"firstBatch": 1}}, ) @@ -3011,13 +3006,12 @@ def aggregate( .. 
_aggregate command: https://mongodb.com/docs/manual/reference/command/aggregate """ - with self._database.client._tmp_session(session, close=False) as s: + with self._database.client._tmp_session(session) as s: return self._aggregate( _CollectionAggregationCommand, pipeline, CommandCursor, session=s, - explicit_session=session is not None, let=let, comment=comment, **kwargs, @@ -3058,7 +3052,7 @@ def aggregate_raw_batches( raise InvalidOperation("aggregate_raw_batches does not support auto encryption") if comment is not None: kwargs["comment"] = comment - with self._database.client._tmp_session(session, close=False) as s: + with self._database.client._tmp_session(session) as s: return cast( RawBatchCursor[_DocumentType], self._aggregate( @@ -3066,7 +3060,6 @@ def aggregate_raw_batches( pipeline, RawBatchCommandCursor, session=s, - explicit_session=session is not None, **kwargs, ), ) diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index bcdeed5f94..a09a67efc9 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -64,7 +64,6 @@ def __init__( batch_size: int = 0, max_await_time_ms: Optional[int] = None, session: Optional[ClientSession] = None, - explicit_session: bool = False, comment: Any = None, ) -> None: """Create a new command cursor.""" @@ -80,7 +79,8 @@ def __init__( self._max_await_time_ms = max_await_time_ms self._timeout = self._collection.database.client.options.timeout self._session = session - self._explicit_session = explicit_session + if self._session is not None: + self._session._attached_to_cursor = True self._killed = self._id == 0 self._comment = comment if self._killed: @@ -197,7 +197,7 @@ def session(self) -> Optional[ClientSession]: .. 
versionadded:: 3.6 """ - if self._explicit_session: + if self._session and not self._session._implicit: return self._session return None @@ -218,9 +218,10 @@ def _die_no_lock(self) -> None: """Closes this cursor without acquiring a lock.""" cursor_id, address = self._prepare_to_die() self._collection.database.client._cleanup_cursor_no_lock( - cursor_id, address, self._sock_mgr, self._session, self._explicit_session + cursor_id, address, self._sock_mgr, self._session ) - if not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session = None self._sock_mgr = None @@ -232,14 +233,15 @@ def _die_lock(self) -> None: address, self._sock_mgr, self._session, - self._explicit_session, ) - if not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session = None self._sock_mgr = None def _end_session(self) -> None: - if self._session and not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session._end_implicit_session() self._session = None @@ -430,7 +432,6 @@ def __init__( batch_size: int = 0, max_await_time_ms: Optional[int] = None, session: Optional[ClientSession] = None, - explicit_session: bool = False, comment: Any = None, ) -> None: """Create a new cursor / iterator over raw batches of BSON data. 
@@ -449,7 +450,6 @@ def __init__( batch_size, max_await_time_ms, session, - explicit_session, comment, ) diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 3dd550f4d5..2cecc5b38a 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -138,10 +138,9 @@ def __init__( if session: self._session = session - self._explicit_session = True + self._session._attached_to_cursor = True else: self._session = None - self._explicit_session = False spec: Mapping[str, Any] = filter or {} validate_is_mapping("filter", spec) @@ -150,7 +149,7 @@ def __init__( if not isinstance(limit, int): raise TypeError(f"limit must be an instance of int, not {type(limit)}") validate_boolean("no_cursor_timeout", no_cursor_timeout) - if no_cursor_timeout and not self._explicit_session: + if no_cursor_timeout and self._session and self._session._implicit: warnings.warn( "use an explicit session with no_cursor_timeout=True " "otherwise the cursor may still timeout after " @@ -283,7 +282,7 @@ def clone(self) -> Cursor[_DocumentType]: def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor: # type: ignore[type-arg] """Internal clone helper.""" if not base: - if self._explicit_session: + if self._session and not self._session._implicit: base = self._clone_base(self._session) else: base = self._clone_base(None) @@ -943,7 +942,7 @@ def session(self) -> Optional[ClientSession]: .. 
versionadded:: 3.6 """ - if self._explicit_session: + if self._session and not self._session._implicit: return self._session return None @@ -1032,9 +1031,10 @@ def _die_no_lock(self) -> None: cursor_id, address = self._prepare_to_die(already_killed) self._collection.database.client._cleanup_cursor_no_lock( - cursor_id, address, self._sock_mgr, self._session, self._explicit_session + cursor_id, address, self._sock_mgr, self._session ) - if not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session = None self._sock_mgr = None @@ -1052,9 +1052,9 @@ def _die_lock(self) -> None: address, self._sock_mgr, self._session, - self._explicit_session, ) - if not self._explicit_session: + if self._session and self._session._implicit: + self._session._attached_to_cursor = False self._session = None self._sock_mgr = None diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index e30f97817c..0d129ba972 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -611,6 +611,8 @@ def create_collection( common.validate_is_mapping("clusteredIndex", clustered_index) with self._client._tmp_session(session) as s: + if s and not s.in_transaction: + s._leave_alive = True # Skip this check in a transaction where listCollections is not # supported. if ( @@ -619,6 +621,8 @@ def create_collection( and name in self._list_collection_names(filter={"name": name}, session=s) ): raise CollectionInvalid("collection %s already exists" % name) + if s: + s._leave_alive = False coll = Collection( self, name, @@ -699,13 +703,12 @@ def aggregate( .. 
_aggregate command: https://mongodb.com/docs/manual/reference/command/aggregate """ - with self.client._tmp_session(session, close=False) as s: + with self.client._tmp_session(session) as s: cmd = _DatabaseAggregationCommand( self, CommandCursor, pipeline, kwargs, - session is not None, user_fields={"cursor": {"firstBatch": 1}}, ) return self.client._retryable_read( @@ -1009,7 +1012,7 @@ def cursor_command( else: command_name = next(iter(command)) - with self._client._tmp_session(session, close=False) as tmp_session: + with self._client._tmp_session(session) as tmp_session: opts = codec_options or DEFAULT_CODEC_OPTIONS if read_preference is None: @@ -1039,7 +1042,6 @@ def cursor_command( conn.address, max_await_time_ms=max_await_time_ms, session=tmp_session, - explicit_session=session is not None, comment=comment, ) cmd_cursor._maybe_pin_connection(conn) @@ -1085,7 +1087,7 @@ def _list_collections( ) cmd = {"listCollections": 1, "cursor": {}} cmd.update(kwargs) - with self._client._tmp_session(session, close=False) as tmp_session: + with self._client._tmp_session(session) as tmp_session: cursor = ( self._command(conn, cmd, read_preference=read_preference, session=tmp_session) )["cursor"] @@ -1094,7 +1096,6 @@ def _list_collections( cursor, conn.address, session=tmp_session, - explicit_session=session is not None, comment=cmd.get("comment"), ) cmd_cursor._maybe_pin_connection(conn) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index ef0663584c..6e716402f4 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2044,17 +2044,18 @@ def _retryable_read( retryable = bool( retryable and self.options.retry_reads and not (session and session.in_transaction) ) - return self._retry_internal( - func, - session, - None, - operation, - is_read=True, - address=address, - read_pref=read_pref, - retryable=retryable, - operation_id=operation_id, - ) + with self._tmp_session(session) as s: + return 
self._retry_internal( + func, + s, + None, + operation, + is_read=True, + address=address, + read_pref=read_pref, + retryable=retryable, + operation_id=operation_id, + ) def _retryable_write( self, @@ -2087,7 +2088,6 @@ def _cleanup_cursor_no_lock( address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, session: Optional[ClientSession], - explicit_session: bool, ) -> None: """Cleanup a cursor from __del__ without locking. @@ -2102,7 +2102,7 @@ def _cleanup_cursor_no_lock( # The cursor will be closed later in a different session. if cursor_id or conn_mgr: self._close_cursor_soon(cursor_id, address, conn_mgr) - if session and not explicit_session: + if session and session._implicit and not session._leave_alive: session._end_implicit_session() def _cleanup_cursor_lock( @@ -2111,7 +2111,6 @@ def _cleanup_cursor_lock( address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, session: Optional[ClientSession], - explicit_session: bool, ) -> None: """Cleanup a cursor from cursor.close() using a lock. @@ -2123,7 +2122,6 @@ def _cleanup_cursor_lock( :param address: The _CursorAddress. :param conn_mgr: The _ConnectionManager for the pinned connection or None. :param session: The cursor's session. - :param explicit_session: True if the session was passed explicitly. 
""" if cursor_id: if conn_mgr and conn_mgr.more_to_come: @@ -2136,7 +2134,7 @@ def _cleanup_cursor_lock( self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr) if conn_mgr: conn_mgr.close() - if session and not explicit_session: + if session and session._implicit and not session._leave_alive: session._end_implicit_session() def _close_cursor_now( @@ -2217,7 +2215,7 @@ def _process_kill_cursors(self) -> None: for address, cursor_id, conn_mgr in pinned_cursors: try: - self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) + self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None) except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2262,7 +2260,7 @@ def _return_server_session( @contextlib.contextmanager def _tmp_session( - self, session: Optional[client_session.ClientSession], close: bool = True + self, session: Optional[client_session.ClientSession] ) -> Generator[Optional[client_session.ClientSession], None]: """If provided session is None, lend a temporary session.""" if session is not None: @@ -2287,7 +2285,7 @@ def _tmp_session( raise finally: # Call end_session when we exit this scope. - if close: + if not s._attached_to_cursor: s.end_session() else: yield None diff --git a/test/asynchronous/test_retryable_reads.py b/test/asynchronous/test_retryable_reads.py index 26454b3823..47ac91b0f5 100644 --- a/test/asynchronous/test_retryable_reads.py +++ b/test/asynchronous/test_retryable_reads.py @@ -218,6 +218,49 @@ async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are # Assert that both events occurred on the same mongos. 
         assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id
 
+    @async_client_context.require_failCommand_fail_point
+    async def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
+        listener = OvertCommandListener()
+        client = await self.async_rs_or_single_client(
+            directConnection=False,
+            event_listeners=[listener],
+            retryReads=True,
+        )
+
+        await client.t.t.insert_one({"x": 1})
+
+        commands = [
+            ("aggregate", lambda: client.t.t.count_documents({})),
+            ("aggregate", lambda: client.t.t.aggregate([{"$match": {}}])),
+            ("count", lambda: client.t.t.estimated_document_count()),
+            ("distinct", lambda: client.t.t.distinct("x")),
+            ("find", lambda: client.t.t.find_one({})),
+            ("listDatabases", lambda: client.list_databases()),
+            ("listCollections", lambda: client.t.list_collections()),
+            ("listIndexes", lambda: client.t.t.list_indexes()),
+        ]
+
+        for command_name, operation in commands:
+            listener.reset()
+            fail_command = {
+                "configureFailPoint": "failCommand",
+                "mode": {"times": 1},
+                "data": {"failCommands": [command_name], "errorCode": 6},
+            }
+
+            async with self.fail_point(fail_command):
+                await operation()
+
+            # Assert that both events occurred on the same session.
+            command_docs = [
+                event.command
+                for event in listener.started_events
+                if event.command_name == command_name
+            ]
+            self.assertEqual(len(command_docs), 2)
+            self.assertEqual(command_docs[0]["lsid"], command_docs[1]["lsid"])
+            self.assertIsNot(command_docs[0], command_docs[1])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py
index fb8a374dac..c9f72ae547 100644
--- a/test/test_retryable_reads.py
+++ b/test/test_retryable_reads.py
@@ -216,6 +216,49 @@ def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are_avail
         # Assert that both events occurred on the same mongos.
         assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id
 
+    @client_context.require_failCommand_fail_point
+    def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
+        listener = OvertCommandListener()
+        client = self.rs_or_single_client(
+            directConnection=False,
+            event_listeners=[listener],
+            retryReads=True,
+        )
+
+        client.t.t.insert_one({"x": 1})
+
+        commands = [
+            ("aggregate", lambda: client.t.t.count_documents({})),
+            ("aggregate", lambda: client.t.t.aggregate([{"$match": {}}])),
+            ("count", lambda: client.t.t.estimated_document_count()),
+            ("distinct", lambda: client.t.t.distinct("x")),
+            ("find", lambda: client.t.t.find_one({})),
+            ("listDatabases", lambda: client.list_databases()),
+            ("listCollections", lambda: client.t.list_collections()),
+            ("listIndexes", lambda: client.t.t.list_indexes()),
+        ]
+
+        for command_name, operation in commands:
+            listener.reset()
+            fail_command = {
+                "configureFailPoint": "failCommand",
+                "mode": {"times": 1},
+                "data": {"failCommands": [command_name], "errorCode": 6},
+            }
+
+            with self.fail_point(fail_command):
+                operation()
+
+            # Assert that both events occurred on the same session.
+            command_docs = [
+                event.command
+                for event in listener.started_events
+                if event.command_name == command_name
+            ]
+            self.assertEqual(len(command_docs), 2)
+            self.assertEqual(command_docs[0]["lsid"], command_docs[1]["lsid"])
+            self.assertIsNot(command_docs[0], command_docs[1])
+
 
 if __name__ == "__main__":
     unittest.main()