Skip to content

PYTHON-4542 Improved sessions API #2335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -204,6 +205,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -222,6 +224,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to bind/unbind the session in ClientSession.__enter__/__exit__. That way the stack of sessions is managed correctly (ie we call _SESSION.reset(token)). Think about how nested cases will work:

session1 = client.start_session(bind=True)
with session1:
    session2 = client.start_session(bind=True)
    with session2:
        coll.find_one() # uses session2
    coll.find_one() # uses session1
coll.find_one() # uses implicit session


@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -1065,6 +1068,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A AsyncClientSession cannot be copied, create a new session instead")


SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 6 additions & 1 deletion pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from _typeshed import SupportsItems

from bson.codec_options import CodecOptions
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.client_session import SESSION, AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.read_preferences import _ServerMode
Expand Down Expand Up @@ -136,9 +136,14 @@ def __init__(
self._killed = False
self._session: Optional[AsyncClientSession]

_SESSION = SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif _SESSION:
self._session = _SESSION
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
10 changes: 8 additions & 2 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from pymongo.asynchronous import client_session, database, uri_parser
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.client_session import SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
Expand Down Expand Up @@ -1355,13 +1355,18 @@ def _close_cursor_soon(
def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession:
server_session = _EmptyServerSession()
opts = client_session.SessionOptions(**kwargs)
return client_session.AsyncClientSession(self, server_session, opts, implicit)
bind = opts._bind
session = client_session.AsyncClientSession(self, server_session, opts, implicit)
if bind:
SESSION.set(session)
return session

def start_session(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.AsyncClientSession:
"""Start a logical session.

Expand All @@ -1384,6 +1389,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(
Expand Down
6 changes: 6 additions & 0 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -203,6 +204,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -221,6 +223,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind

@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -1060,6 +1063,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A ClientSession cannot be copied, create a new session instead")


SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 6 additions & 1 deletion pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

from bson.codec_options import CodecOptions
from pymongo.read_preferences import _ServerMode
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.client_session import SESSION, ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.pool import Connection

Expand Down Expand Up @@ -136,9 +136,14 @@ def __init__(
self._killed = False
self._session: Optional[ClientSession]

_SESSION = SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif _SESSION:
self._session = _SESSION
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
10 changes: 8 additions & 2 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
from pymongo.synchronous import client_session, database, uri_parser
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
from pymongo.synchronous.client_session import SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
Expand Down Expand Up @@ -1353,13 +1353,18 @@ def _close_cursor_soon(
def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession:
server_session = _EmptyServerSession()
opts = client_session.SessionOptions(**kwargs)
return client_session.ClientSession(self, server_session, opts, implicit)
bind = opts._bind
session = client_session.ClientSession(self, server_session, opts, implicit)
if bind:
SESSION.set(session)
return session

def start_session(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.ClientSession:
"""Start a logical session.

Expand All @@ -1382,6 +1387,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
Expand Down
7 changes: 7 additions & 0 deletions test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,13 @@ async def test_cursor_clone(self):
clone = cursor.clone()
self.assertTrue(clone.session is s)

# Explicit session via context variable.
async with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)
clone = cursor.clone()
self.assertTrue(clone.session is s)

# No explicit session.
cursor = coll.find(batch_size=2)
await anext(cursor)
Expand Down
7 changes: 7 additions & 0 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,13 @@ def test_cursor_clone(self):
clone = cursor.clone()
self.assertTrue(clone.session is s)

# Explicit session via context variable.
with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)
clone = cursor.clone()
self.assertTrue(clone.session is s)

# No explicit session.
cursor = coll.find(batch_size=2)
next(cursor)
Expand Down
Loading