Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pymongo/asynchronous/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(
cursor_class: type[AsyncCommandCursor[Any]],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
Expand Down Expand Up @@ -92,7 +91,6 @@ def __init__(
self._options["cursor"]["batchSize"] = self._batch_size

self._cursor_class = cursor_class
self._explicit_session = explicit_session
self._user_fields = user_fields
self._result_processor = result_processor

Expand Down Expand Up @@ -197,7 +195,6 @@ async def get_cursor(
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session,
explicit_session=self._explicit_session,
comment=self._options.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
Expand Down
11 changes: 5 additions & 6 deletions pymongo/asynchronous/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> N
)

async def _run_aggregation_cmd(
self, session: Optional[AsyncClientSession], explicit_session: bool
self, session: Optional[AsyncClientSession]
) -> AsyncCommandCursor: # type: ignore[type-arg]
"""Run the full aggregation pipeline for this AsyncChangeStream and return
the corresponding AsyncCommandCursor.
Expand All @@ -246,7 +246,6 @@ async def _run_aggregation_cmd(
AsyncCommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
Expand All @@ -258,10 +257,10 @@ async def _run_aggregation_cmd(
)

async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg]
async with self._client._tmp_session(self._session, close=False) as s:
return await self._run_aggregation_cmd(
session=s, explicit_session=self._session is not None
)
async with self._client._tmp_session(self._session) as s:
if s:
s.leave_alive = True
return await self._run_aggregation_cmd(session=s)

async def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""
Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,6 @@ async def _process_results_cursor(
result["cursor"],
conn.address,
session=session,
explicit_session=session is not None,
comment=self.comment,
)
await cmd_cursor._maybe_pin_connection(conn)
Expand Down Expand Up @@ -538,6 +537,7 @@ async def _execute_command(
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
session.leave_alive = True
conn.send_cluster_time(cmd, session, self.client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
Expand Down
30 changes: 29 additions & 1 deletion pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._attached_to_cursor = False
self._leave_alive = False

async def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
Expand All @@ -535,7 +537,7 @@ async def _end_session(self, lock: bool) -> None:

def _end_implicit_session(self) -> None:
# Implicit sessions can't be part of transactions or pinned connections
if self._server_session is not None:
if not self._leave_alive and self._server_session is not None:
self._client._return_server_session(self._server_session)
self._server_session = None

Expand Down Expand Up @@ -588,6 +590,32 @@ def operation_time(self) -> Optional[Timestamp]:
"""
return self._operation_time

@property
def implicit(self) -> bool:
"""Whether this session was implicitly created by the driver."""
return self._implicit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is personal preference but do we really need these @property helpers? Usually we just access the private attribute directly, eg:

if session._implicit:...
if session._attached_to_cursor:...

This way there's less indirection and boilerplate code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For internal attributes it makes more sense to not have @property, agreed.


@property
def attached_to_cursor(self) -> bool:
"""Whether this session is owned by a cursor."""
return self._attached_to_cursor

@attached_to_cursor.setter
def attached_to_cursor(self, value: bool) -> None:
self._attached_to_cursor = value

@property
def leave_alive(self) -> bool:
"""Whether to leave this session alive when it is
no longer in use.
Typically used for implicit sessions that are used for multiple operations within a single larger operation.
"""
return self._leave_alive

@leave_alive.setter
def leave_alive(self, value: bool) -> None:
self._leave_alive = value

def _inherit_option(self, name: str, val: _T) -> _T:
"""Return the inherited TransactionOption value."""
if val:
Expand Down
13 changes: 3 additions & 10 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2549,7 +2549,6 @@ async def _list_indexes(
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
)
read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
explicit_session = session is not None

async def _cmd(
session: Optional[AsyncClientSession],
Expand All @@ -2576,13 +2575,12 @@ async def _cmd(
cursor,
conn.address,
session=session,
explicit_session=explicit_session,
comment=cmd.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor

async with self._database.client._tmp_session(session, False) as s:
async with self._database.client._tmp_session(session) as s:
return await self._database.client._retryable_read(
_cmd, read_pref, s, operation=_Op.LIST_INDEXES
)
Expand Down Expand Up @@ -2678,7 +2676,6 @@ async def list_search_indexes(
AsyncCommandCursor,
pipeline,
kwargs,
explicit_session=session is not None,
comment=comment,
user_fields={"cursor": {"firstBatch": 1}},
)
Expand Down Expand Up @@ -2900,7 +2897,6 @@ async def _aggregate(
pipeline: _Pipeline,
cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg]
session: Optional[AsyncClientSession],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
Expand All @@ -2912,7 +2908,6 @@ async def _aggregate(
cursor_class,
pipeline,
kwargs,
explicit_session,
let,
user_fields={"cursor": {"firstBatch": 1}},
)
Expand Down Expand Up @@ -3018,13 +3013,12 @@ async def aggregate(
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
async with self._database.client._tmp_session(session, close=False) as s:
async with self._database.client._tmp_session(session) as s:
return await self._aggregate(
_CollectionAggregationCommand,
pipeline,
AsyncCommandCursor,
session=s,
explicit_session=session is not None,
let=let,
comment=comment,
**kwargs,
Expand Down Expand Up @@ -3065,15 +3059,14 @@ async def aggregate_raw_batches(
raise InvalidOperation("aggregate_raw_batches does not support auto encryption")
if comment is not None:
kwargs["comment"] = comment
async with self._database.client._tmp_session(session, close=False) as s:
async with self._database.client._tmp_session(session) as s:
return cast(
AsyncRawBatchCursor[_DocumentType],
await self._aggregate(
_CollectionRawAggregationCommand,
pipeline,
AsyncRawBatchCommandCursor,
session=s,
explicit_session=session is not None,
**kwargs,
),
)
Expand Down
20 changes: 10 additions & 10 deletions pymongo/asynchronous/command_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def __init__(
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
Expand All @@ -80,7 +79,8 @@ def __init__(
self._max_await_time_ms = max_await_time_ms
self._timeout = self._collection.database.client.options.timeout
self._session = session
self._explicit_session = explicit_session
if self._session is not None:
self._session.attached_to_cursor = True
self._killed = self._id == 0
self._comment = comment
if self._killed:
Expand Down Expand Up @@ -197,7 +197,7 @@ def session(self) -> Optional[AsyncClientSession]:

.. versionadded:: 3.6
"""
if self._explicit_session:
if self._session and not self._session.implicit:
return self._session
return None

Expand All @@ -218,9 +218,10 @@ def _die_no_lock(self) -> None:
"""Closes this cursor without acquiring a lock."""
cursor_id, address = self._prepare_to_die()
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
cursor_id, address, self._sock_mgr, self._session
)
if not self._explicit_session:
if self._session and self._session.implicit:
self._session.attached_to_cursor = False
self._session = None
self._sock_mgr = None

Expand All @@ -232,14 +233,15 @@ async def _die_lock(self) -> None:
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
if self._session and self._session.implicit:
self._session.attached_to_cursor = False
self._session = None
self._sock_mgr = None

def _end_session(self) -> None:
if self._session and not self._explicit_session:
if self._session and self._session.implicit and not self._session.leave_alive:
self._session.attached_to_cursor = False
self._session._end_implicit_session()
self._session = None

Expand Down Expand Up @@ -430,7 +432,6 @@ def __init__(
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
Expand All @@ -449,7 +450,6 @@ def __init__(
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)

Expand Down
18 changes: 9 additions & 9 deletions pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,9 @@ def __init__(

if session:
self._session = session
self._explicit_session = True
self._session.attached_to_cursor = True
else:
self._session = None
self._explicit_session = False

spec: Mapping[str, Any] = filter or {}
validate_is_mapping("filter", spec)
Expand All @@ -150,7 +149,7 @@ def __init__(
if not isinstance(limit, int):
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
validate_boolean("no_cursor_timeout", no_cursor_timeout)
if no_cursor_timeout and not self._explicit_session:
if no_cursor_timeout and self._session and self._session.implicit:
warnings.warn(
"use an explicit session with no_cursor_timeout=True "
"otherwise the cursor may still timeout after "
Expand Down Expand Up @@ -283,7 +282,7 @@ def clone(self) -> AsyncCursor[_DocumentType]:
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg]
"""Internal clone helper."""
if not base:
if self._explicit_session:
if self._session and not self._session.implicit:
base = self._clone_base(self._session)
else:
base = self._clone_base(None)
Expand Down Expand Up @@ -945,7 +944,7 @@ def session(self) -> Optional[AsyncClientSession]:

.. versionadded:: 3.6
"""
if self._explicit_session:
if self._session and not self._session.implicit:
return self._session
return None

Expand Down Expand Up @@ -1034,9 +1033,10 @@ def _die_no_lock(self) -> None:

cursor_id, address = self._prepare_to_die(already_killed)
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
cursor_id, address, self._sock_mgr, self._session
)
if not self._explicit_session:
if self._session and self._session.implicit:
self._session.attached_to_cursor = False
self._session = None
self._sock_mgr = None

Expand All @@ -1054,9 +1054,9 @@ async def _die_lock(self) -> None:
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
if self._session and self._session.implicit:
self._session.attached_to_cursor = False
self._session = None
self._sock_mgr = None

Expand Down
11 changes: 5 additions & 6 deletions pymongo/asynchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,8 @@ async def create_collection(
common.validate_is_mapping("clusteredIndex", clustered_index)

async with self._client._tmp_session(session) as s:
if s:
s.leave_alive = True
# Skip this check in a transaction where listCollections is not
# supported.
if (
Expand Down Expand Up @@ -699,13 +701,12 @@ async def aggregate(
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
async with self.client._tmp_session(session, close=False) as s:
async with self.client._tmp_session(session) as s:
cmd = _DatabaseAggregationCommand(
self,
AsyncCommandCursor,
pipeline,
kwargs,
session is not None,
user_fields={"cursor": {"firstBatch": 1}},
)
return await self.client._retryable_read(
Expand Down Expand Up @@ -1011,7 +1012,7 @@ async def cursor_command(
else:
command_name = next(iter(command))

async with self._client._tmp_session(session, close=False) as tmp_session:
async with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS

if read_preference is None:
Expand Down Expand Up @@ -1043,7 +1044,6 @@ async def cursor_command(
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
explicit_session=session is not None,
comment=comment,
)
await cmd_cursor._maybe_pin_connection(conn)
Expand Down Expand Up @@ -1089,7 +1089,7 @@ async def _list_collections(
)
cmd = {"listCollections": 1, "cursor": {}}
cmd.update(kwargs)
async with self._client._tmp_session(session, close=False) as tmp_session:
async with self._client._tmp_session(session) as tmp_session:
cursor = (
await self._command(conn, cmd, read_preference=read_preference, session=tmp_session)
)["cursor"]
Expand All @@ -1098,7 +1098,6 @@ async def _list_collections(
cursor,
conn.address,
session=tmp_session,
explicit_session=session is not None,
comment=cmd.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
Expand Down
Loading
Loading