Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pymongo/asynchronous/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(
cursor_class: type[AsyncCommandCursor[Any]],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
Expand Down Expand Up @@ -92,7 +91,6 @@ def __init__(
self._options["cursor"]["batchSize"] = self._batch_size

self._cursor_class = cursor_class
self._explicit_session = explicit_session
self._user_fields = user_fields
self._result_processor = result_processor

Expand Down Expand Up @@ -197,7 +195,6 @@ async def get_cursor(
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session,
explicit_session=self._explicit_session,
comment=self._options.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
Expand Down
9 changes: 3 additions & 6 deletions pymongo/asynchronous/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> N
)

async def _run_aggregation_cmd(
self, session: Optional[AsyncClientSession], explicit_session: bool
self, session: Optional[AsyncClientSession]
) -> AsyncCommandCursor: # type: ignore[type-arg]
"""Run the full aggregation pipeline for this AsyncChangeStream and return
the corresponding AsyncCommandCursor.
Expand All @@ -246,7 +246,6 @@ async def _run_aggregation_cmd(
AsyncCommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
Expand All @@ -258,10 +257,8 @@ async def _run_aggregation_cmd(
)

async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg]
async with self._client._tmp_session(self._session, close=False) as s:
return await self._run_aggregation_cmd(
session=s, explicit_session=self._session is not None
)
async with self._client._tmp_session(self._session) as s:
return await self._run_aggregation_cmd(session=s)

async def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""
Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,6 @@ async def _process_results_cursor(
result["cursor"],
conn.address,
session=session,
explicit_session=session is not None,
comment=self.comment,
)
await cmd_cursor._maybe_pin_connection(conn)
Expand Down Expand Up @@ -538,6 +537,7 @@ async def _execute_command(
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
session.leave_alive = True
conn.send_cluster_time(cmd, session, self.client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
Expand Down
30 changes: 29 additions & 1 deletion pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._attached_to_cursor = False
self._leave_alive = False

async def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
Expand All @@ -535,7 +537,7 @@ async def _end_session(self, lock: bool) -> None:

def _end_implicit_session(self) -> None:
# Implicit sessions can't be part of transactions or pinned connections
if self._server_session is not None:
if not self._leave_alive and self._server_session is not None:
self._client._return_server_session(self._server_session)
self._server_session = None

Expand Down Expand Up @@ -588,6 +590,32 @@ def operation_time(self) -> Optional[Timestamp]:
"""
return self._operation_time

@property
def _is_implicit(self) -> bool:
"""Whether this session was implicitly created by the driver."""
return self._implicit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is personal preference but do we really need these @property helpers? Usually we just access the private attribute directly, eg:

if session._implicit:...
if session._attached_to_cursor:...

This way there's less indirection and boilerplate code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For internal attributes it makes more sense to not have @property, agreed.


@property
def _is_attached_to_cursor(self) -> bool:
"""Whether this session is owned by a cursor."""
return self._attached_to_cursor

@_is_attached_to_cursor.setter
def _is_attached_to_cursor(self, value: bool) -> None:
self._attached_to_cursor = value

@property
def leave_alive(self) -> bool:
"""Whether to leave this session alive when it is
no longer in use.
Typically used for implicit sessions that are used for multiple operations within a single larger operation.
"""
return self._leave_alive

@leave_alive.setter
def leave_alive(self, value: bool) -> None:
self._leave_alive = value

def _inherit_option(self, name: str, val: _T) -> _T:
"""Return the inherited TransactionOption value."""
if val:
Expand Down
13 changes: 3 additions & 10 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2549,7 +2549,6 @@ async def _list_indexes(
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
)
read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
explicit_session = session is not None

async def _cmd(
session: Optional[AsyncClientSession],
Expand All @@ -2576,13 +2575,12 @@ async def _cmd(
cursor,
conn.address,
session=session,
explicit_session=explicit_session,
comment=cmd.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor

async with self._database.client._tmp_session(session, False) as s:
async with self._database.client._tmp_session(session) as s:
return await self._database.client._retryable_read(
_cmd, read_pref, s, operation=_Op.LIST_INDEXES
)
Expand Down Expand Up @@ -2678,7 +2676,6 @@ async def list_search_indexes(
AsyncCommandCursor,
pipeline,
kwargs,
explicit_session=session is not None,
comment=comment,
user_fields={"cursor": {"firstBatch": 1}},
)
Expand Down Expand Up @@ -2900,7 +2897,6 @@ async def _aggregate(
pipeline: _Pipeline,
cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg]
session: Optional[AsyncClientSession],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
Expand All @@ -2912,7 +2908,6 @@ async def _aggregate(
cursor_class,
pipeline,
kwargs,
explicit_session,
let,
user_fields={"cursor": {"firstBatch": 1}},
)
Expand Down Expand Up @@ -3018,13 +3013,12 @@ async def aggregate(
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
async with self._database.client._tmp_session(session, close=False) as s:
async with self._database.client._tmp_session(session) as s:
return await self._aggregate(
_CollectionAggregationCommand,
pipeline,
AsyncCommandCursor,
session=s,
explicit_session=session is not None,
let=let,
comment=comment,
**kwargs,
Expand Down Expand Up @@ -3065,15 +3059,14 @@ async def aggregate_raw_batches(
raise InvalidOperation("aggregate_raw_batches does not support auto encryption")
if comment is not None:
kwargs["comment"] = comment
async with self._database.client._tmp_session(session, close=False) as s:
async with self._database.client._tmp_session(session) as s:
return cast(
AsyncRawBatchCursor[_DocumentType],
await self._aggregate(
_CollectionRawAggregationCommand,
pipeline,
AsyncRawBatchCommandCursor,
session=s,
explicit_session=session is not None,
**kwargs,
),
)
Expand Down
20 changes: 10 additions & 10 deletions pymongo/asynchronous/command_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def __init__(
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
Expand All @@ -80,7 +79,8 @@ def __init__(
self._max_await_time_ms = max_await_time_ms
self._timeout = self._collection.database.client.options.timeout
self._session = session
self._explicit_session = explicit_session
if self._session is not None:
self._session._is_attached_to_cursor = True
self._killed = self._id == 0
self._comment = comment
if self._killed:
Expand Down Expand Up @@ -197,7 +197,7 @@ def session(self) -> Optional[AsyncClientSession]:

.. versionadded:: 3.6
"""
if self._explicit_session:
if self._session and not self._session._is_implicit:
return self._session
return None

Expand All @@ -218,9 +218,10 @@ def _die_no_lock(self) -> None:
"""Closes this cursor without acquiring a lock."""
cursor_id, address = self._prepare_to_die()
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
cursor_id, address, self._sock_mgr, self._session
)
if not self._explicit_session:
if self._session and self._session._is_implicit:
self._session._is_attached_to_cursor = False
self._session = None
self._sock_mgr = None

Expand All @@ -232,14 +233,15 @@ async def _die_lock(self) -> None:
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
if self._session and self._session._is_implicit:
self._session._is_attached_to_cursor = False
self._session = None
self._sock_mgr = None

def _end_session(self) -> None:
if self._session and not self._explicit_session:
if self._session and self._session._is_implicit and not self._session.leave_alive:
self._session._is_attached_to_cursor = False
self._session._end_implicit_session()
self._session = None

Expand Down Expand Up @@ -430,7 +432,6 @@ def __init__(
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
Expand All @@ -449,7 +450,6 @@ def __init__(
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)

Expand Down
18 changes: 9 additions & 9 deletions pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,9 @@ def __init__(

if session:
self._session = session
self._explicit_session = True
self._session._is_attached_to_cursor = True
else:
self._session = None
self._explicit_session = False

spec: Mapping[str, Any] = filter or {}
validate_is_mapping("filter", spec)
Expand All @@ -150,7 +149,7 @@ def __init__(
if not isinstance(limit, int):
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
validate_boolean("no_cursor_timeout", no_cursor_timeout)
if no_cursor_timeout and not self._explicit_session:
if no_cursor_timeout and self._session and self._session._is_implicit:
warnings.warn(
"use an explicit session with no_cursor_timeout=True "
"otherwise the cursor may still timeout after "
Expand Down Expand Up @@ -283,7 +282,7 @@ def clone(self) -> AsyncCursor[_DocumentType]:
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg]
"""Internal clone helper."""
if not base:
if self._explicit_session:
if self._session and not self._session._is_implicit:
base = self._clone_base(self._session)
else:
base = self._clone_base(None)
Expand Down Expand Up @@ -945,7 +944,7 @@ def session(self) -> Optional[AsyncClientSession]:

.. versionadded:: 3.6
"""
if self._explicit_session:
if self._session and not self._session._is_implicit:
return self._session
return None

Expand Down Expand Up @@ -1034,9 +1033,10 @@ def _die_no_lock(self) -> None:

cursor_id, address = self._prepare_to_die(already_killed)
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
cursor_id, address, self._sock_mgr, self._session
)
if not self._explicit_session:
if self._session and self._session._is_implicit:
self._session._is_attached_to_cursor = False
self._session = None
self._sock_mgr = None

Expand All @@ -1054,9 +1054,9 @@ async def _die_lock(self) -> None:
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
if self._session and self._session._is_implicit:
self._session._is_attached_to_cursor = False
self._session = None
self._sock_mgr = None

Expand Down
11 changes: 5 additions & 6 deletions pymongo/asynchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,8 @@ async def create_collection(
common.validate_is_mapping("clusteredIndex", clustered_index)

async with self._client._tmp_session(session) as s:
if s:
s.leave_alive = True
# Skip this check in a transaction where listCollections is not
# supported.
if (
Expand Down Expand Up @@ -699,13 +701,12 @@ async def aggregate(
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
async with self.client._tmp_session(session, close=False) as s:
async with self.client._tmp_session(session) as s:
cmd = _DatabaseAggregationCommand(
self,
AsyncCommandCursor,
pipeline,
kwargs,
session is not None,
user_fields={"cursor": {"firstBatch": 1}},
)
return await self.client._retryable_read(
Expand Down Expand Up @@ -1011,7 +1012,7 @@ async def cursor_command(
else:
command_name = next(iter(command))

async with self._client._tmp_session(session, close=False) as tmp_session:
async with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS

if read_preference is None:
Expand Down Expand Up @@ -1043,7 +1044,6 @@ async def cursor_command(
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
explicit_session=session is not None,
comment=comment,
)
await cmd_cursor._maybe_pin_connection(conn)
Expand Down Expand Up @@ -1089,7 +1089,7 @@ async def _list_collections(
)
cmd = {"listCollections": 1, "cursor": {}}
cmd.update(kwargs)
async with self._client._tmp_session(session, close=False) as tmp_session:
async with self._client._tmp_session(session) as tmp_session:
cursor = (
await self._command(conn, cmd, read_preference=read_preference, session=tmp_session)
)["cursor"]
Expand All @@ -1098,7 +1098,6 @@ async def _list_collections(
cursor,
conn.address,
session=tmp_session,
explicit_session=session is not None,
comment=cmd.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
Expand Down
Loading
Loading