Skip to content

Commit 8d5a976

Browse files
authored
Merge pull request #2839 from bagerard/juannyG-transactions-global-session-ver
Juanny g - transactions global session ver
2 parents 3305136 + a81a568 commit 8d5a976

File tree

14 files changed

+606
-56
lines changed

14 files changed

+606
-56
lines changed

.github/workflows/start_mongo.sh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ MONGODB=$1
55
mongodb_dir=$(find ${PWD}/ -type d -name "mongodb-linux-x86_64*")
66

77
mkdir $mongodb_dir/data
8-
$mongodb_dir/bin/mongod --dbpath $mongodb_dir/data --logpath $mongodb_dir/mongodb.log --fork
98

9+
$mongodb_dir/bin/mongod --dbpath $mongodb_dir/data --logpath $mongodb_dir/mongodb.log --fork --replSet mongoengine
1010
if (( $(echo "$MONGODB < 6.0" | bc -l) )); then
11-
mongo --quiet --eval 'db.runCommand("ping").ok' # Make sure mongo is awake
11+
mongo --verbose --eval "rs.initiate()"
12+
mongo --quiet --eval "rs.status().ok"
1213
else
13-
mongosh --quiet --eval 'db.runCommand("ping").ok' # Make sure mongo is awake
14+
mongosh --verbose --eval "rs.initiate()"
15+
mongosh --quiet --eval "rs.status().ok"
1416
fi

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,4 @@ that much better:
265265
* Ido Shraga (https://github.com/idoshr)
266266
* Terence Honles (https://github.com/terencehonles)
267267
* Sean Bermejo (https://github.com/seanbermejo)
268+
* Juan Gutierrez (https://github.com/juannyg)

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM mongo:5
1+
FROM mongo:4.0
22

33
COPY ./entrypoint.sh entrypoint.sh
44
RUN chmod u+x entrypoint.sh

docs/changelog.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ Changelog
77
Development
88
===========
99
- (Fill this out as you fix issues and develop your features).
10+
- Add support for transaction through run_in_transaction (kudos to juannyG for this) #2569
11+
Some considerations:
12+
- make sure to read https://www.mongodb.com/docs/manual/core/transactions-in-applications/#callback-api-vs-core-api
13+
- run_in_transaction context manager relies on Pymongo coreAPI, it will retry automatically in case of `UnknownTransactionCommitResult` but not `TransientTransactionError` exceptions
14+
- Using .count() in a transaction will always use Collection.count_document (as estimated_document_count is not supported in transactions)
1015

1116
Changes in 0.29.0
1217
=================

mongoengine/connection.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import collections
2+
import threading
13
import warnings
24

35
from pymongo import MongoClient, ReadPreference, uri_parser
@@ -35,6 +37,7 @@
3537
_connections = {}
3638
_dbs = {}
3739

40+
3841
READ_PREFERENCE = ReadPreference.PRIMARY
3942

4043

@@ -470,3 +473,37 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
470473
# Support old naming convention
471474
_get_connection = get_connection
472475
_get_db = get_db
476+
477+
478+
class _LocalSessions(threading.local):
479+
def __init__(self):
480+
self.sessions = collections.deque()
481+
482+
def append(self, session):
483+
self.sessions.append(session)
484+
485+
def get_current(self):
486+
if len(self.sessions):
487+
return self.sessions[-1]
488+
489+
def clear_current(self):
490+
if len(self.sessions):
491+
self.sessions.pop()
492+
493+
def clear_all(self):
494+
self.sessions.clear()
495+
496+
497+
_local_sessions = _LocalSessions()
498+
499+
500+
def _set_session(session):
501+
_local_sessions.append(session)
502+
503+
504+
def _get_session():
505+
return _local_sessions.get_current()
506+
507+
508+
def _clear_session():
509+
return _local_sessions.clear_current()

mongoengine/context_managers.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
import contextlib
2+
import logging
23
import threading
34
from contextlib import contextmanager
45

6+
from pymongo.errors import ConnectionFailure, OperationFailure
57
from pymongo.read_concern import ReadConcern
68
from pymongo.write_concern import WriteConcern
79

810
from mongoengine.base.fields import _no_dereference_for_fields
911
from mongoengine.common import _import_class
10-
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
12+
from mongoengine.connection import (
13+
DEFAULT_CONNECTION_NAME,
14+
_clear_session,
15+
_get_session,
16+
_set_session,
17+
get_connection,
18+
get_db,
19+
)
1120
from mongoengine.pymongo_support import count_documents
1221

1322
__all__ = (
@@ -19,6 +28,7 @@
1928
"set_write_concern",
2029
"set_read_write_concern",
2130
"no_dereferencing_active_for_class",
31+
"run_in_transaction",
2232
)
2333

2434

@@ -231,11 +241,11 @@ def __init__(self, alias=DEFAULT_CONNECTION_NAME):
231241
}
232242

233243
def _turn_on_profiling(self):
234-
profile_update_res = self.db.command({"profile": 0})
244+
profile_update_res = self.db.command({"profile": 0}, session=_get_session())
235245
self.initial_profiling_level = profile_update_res["was"]
236246

237247
self.db.system.profile.drop()
238-
self.db.command({"profile": 2})
248+
self.db.command({"profile": 2}, session=_get_session())
239249

240250
def _resets_profiling(self):
241251
self.db.command({"profile": self.initial_profiling_level})
@@ -311,3 +321,60 @@ def set_read_write_concern(collection, write_concerns, read_concerns):
311321
write_concern=WriteConcern(**combined_write_concerns),
312322
read_concern=ReadConcern(**combined_read_concerns),
313323
)
324+
325+
326+
def _commit_with_retry(session):
327+
while True:
328+
try:
329+
# Commit uses write concern set at transaction start.
330+
session.commit_transaction()
331+
break
332+
except (ConnectionFailure, OperationFailure) as exc:
333+
# Can retry commit
334+
if exc.has_error_label("UnknownTransactionCommitResult"):
335+
logging.warning(
336+
"UnknownTransactionCommitResult, retrying commit operation ..."
337+
)
338+
continue
339+
else:
340+
# Error during commit
341+
raise
342+
343+
344+
@contextmanager
345+
def run_in_transaction(
346+
alias=DEFAULT_CONNECTION_NAME, session_kwargs=None, transaction_kwargs=None
347+
):
348+
"""run_in_transaction context manager
349+
Execute queries within the context in a database transaction.
350+
351+
Usage:
352+
353+
.. code-block:: python
354+
355+
class A(Document):
356+
name = StringField()
357+
358+
with run_in_transaction():
359+
a_doc = A.objects.create(name="a")
360+
a_doc.update(name="b")
361+
362+
Be aware that:
363+
- Mongo transactions run inside a session which is bound to a connection. If you attempt to
364+
execute a transaction across a different connection alias, pymongo will raise an exception. In
365+
other words: you cannot create a transaction that crosses different database connections. That
366+
said, multiple transaction can be nested within the same session for particular connection.
367+
368+
For more information regarding pymongo transactions: https://pymongo.readthedocs.io/en/stable/api/pymongo/client_session.html#transactions
369+
"""
370+
conn = get_connection(alias)
371+
session_kwargs = session_kwargs or {}
372+
with conn.start_session(**session_kwargs) as session:
373+
transaction_kwargs = transaction_kwargs or {}
374+
with session.start_transaction(**transaction_kwargs):
375+
try:
376+
_set_session(session)
377+
yield
378+
_commit_with_retry(session)
379+
finally:
380+
_clear_session()

mongoengine/dereference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
get_document,
99
)
1010
from mongoengine.base.datastructures import LazyReference
11-
from mongoengine.connection import get_db
11+
from mongoengine.connection import _get_session, get_db
1212
from mongoengine.document import Document, EmbeddedDocument
1313
from mongoengine.fields import (
1414
DictField,
@@ -187,13 +187,15 @@ def _fetch_objects(self, doc_type=None):
187187

188188
if doc_type:
189189
references = doc_type._get_db()[collection].find(
190-
{"_id": {"$in": refs}}
190+
{"_id": {"$in": refs}}, session=_get_session()
191191
)
192192
for ref in references:
193193
doc = doc_type._from_son(ref)
194194
object_map[(collection, doc.id)] = doc
195195
else:
196-
references = get_db()[collection].find({"_id": {"$in": refs}})
196+
references = get_db()[collection].find(
197+
{"_id": {"$in": refs}}, session=_get_session()
198+
)
197199
for ref in references:
198200
if "_cls" in ref:
199201
doc = get_document(ref["_cls"])._from_son(ref)

mongoengine/document.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
get_document,
1616
)
1717
from mongoengine.common import _import_class
18-
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
18+
from mongoengine.connection import (
19+
DEFAULT_CONNECTION_NAME,
20+
_get_session,
21+
get_db,
22+
)
1923
from mongoengine.context_managers import (
2024
set_write_concern,
2125
switch_collection,
@@ -271,7 +275,7 @@ def _get_capped_collection(cls):
271275
if max_documents:
272276
opts["max"] = max_documents
273277

274-
return db.create_collection(collection_name, **opts)
278+
return db.create_collection(collection_name, session=_get_session(), **opts)
275279

276280
@classmethod
277281
def _get_timeseries_collection(cls):
@@ -502,17 +506,21 @@ def _save_create(self, doc, force_insert, write_concern):
502506
collection = self._get_collection()
503507
with set_write_concern(collection, write_concern) as wc_collection:
504508
if force_insert:
505-
return wc_collection.insert_one(doc).inserted_id
509+
return wc_collection.insert_one(doc, session=_get_session()).inserted_id
506510
# insert_one will provoke UniqueError alongside save does not
507511
# therefore, it need to catch and call replace_one.
508512
if "_id" in doc:
509513
select_dict = {"_id": doc["_id"]}
510514
select_dict = self._integrate_shard_key(doc, select_dict)
511-
raw_object = wc_collection.find_one_and_replace(select_dict, doc)
515+
raw_object = wc_collection.find_one_and_replace(
516+
select_dict, doc, session=_get_session()
517+
)
512518
if raw_object:
513519
return doc["_id"]
514520

515-
object_id = wc_collection.insert_one(doc).inserted_id
521+
object_id = wc_collection.insert_one(
522+
doc, session=_get_session()
523+
).inserted_id
516524

517525
return object_id
518526

@@ -570,7 +578,7 @@ def _save_update(self, doc, save_condition, write_concern):
570578
upsert = save_condition is None
571579
with set_write_concern(collection, write_concern) as wc_collection:
572580
last_error = wc_collection.update_one(
573-
select_dict, update_doc, upsert=upsert
581+
select_dict, update_doc, upsert=upsert, session=_get_session()
574582
).raw_result
575583
if not upsert and last_error["n"] == 0:
576584
raise SaveConditionError(
@@ -873,7 +881,7 @@ def drop_collection(cls):
873881
)
874882
cls._collection = None
875883
db = cls._get_db()
876-
db.drop_collection(coll_name)
884+
db.drop_collection(coll_name, session=_get_session())
877885

878886
@classmethod
879887
def create_index(cls, keys, background=False, **kwargs):
@@ -890,7 +898,9 @@ def create_index(cls, keys, background=False, **kwargs):
890898
index_spec["background"] = background
891899
index_spec.update(kwargs)
892900

893-
return cls._get_collection().create_index(fields, **index_spec)
901+
return cls._get_collection().create_index(
902+
fields, session=_get_session(), **index_spec
903+
)
894904

895905
@classmethod
896906
def ensure_indexes(cls):
@@ -936,7 +946,9 @@ def ensure_indexes(cls):
936946
if "cls" in opts:
937947
del opts["cls"]
938948

939-
collection.create_index(fields, background=background, **opts)
949+
collection.create_index(
950+
fields, background=background, session=_get_session(), **opts
951+
)
940952

941953
# If _cls is being used (for polymorphism), it needs an index,
942954
# only if another index doesn't begin with _cls
@@ -946,7 +958,9 @@ def ensure_indexes(cls):
946958
if "cls" in index_opts:
947959
del index_opts["cls"]
948960

949-
collection.create_index("_cls", background=background, **index_opts)
961+
collection.create_index(
962+
"_cls", background=background, session=_get_session(), **index_opts
963+
)
950964

951965
@classmethod
952966
def list_indexes(cls):
@@ -1024,7 +1038,7 @@ def compare_indexes(cls):
10241038

10251039
existing = []
10261040
collection = cls._get_collection()
1027-
for info in collection.index_information().values():
1041+
for info in collection.index_information(session=_get_session()).values():
10281042
if "_fts" in info["key"][0]:
10291043
# Useful for text indexes (but not only)
10301044
index_type = info["key"][0][1]

0 commit comments

Comments
 (0)