Skip to content

Commit c217cae

Browse files
committed
Fix tests and Improve commit handler so that it supports UnknownTransactionCommitResult
1 parent c4adaf9 commit c217cae

File tree

3 files changed

+49
-26
lines changed

3 files changed

+49
-26
lines changed

mongoengine/context_managers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import threading
33
from contextlib import contextmanager
44

5+
from pymongo.errors import ConnectionFailure, OperationFailure
56
from pymongo.read_concern import ReadConcern
67
from pymongo.write_concern import WriteConcern
78

@@ -321,6 +322,23 @@ def set_read_write_concern(collection, write_concerns, read_concerns):
321322
)
322323

323324

325+
def _commit_with_retry(session):
326+
while True:
327+
try:
328+
# Commit uses write concern set at transaction start.
329+
session.commit_transaction()
330+
print("Transaction committed.")
331+
break
332+
except (ConnectionFailure, OperationFailure) as exc:
333+
# Can retry commit
334+
if exc.has_error_label("UnknownTransactionCommitResult"):
335+
print("UnknownTransactionCommitResult, retrying commit operation ...")
336+
continue
337+
else:
338+
print("Error during commit ...")
339+
raise
340+
341+
324342
@contextmanager
325343
def run_in_transaction(
326344
alias=DEFAULT_CONNECTION_NAME, session_kwargs=None, transaction_kwargs=None
@@ -355,5 +373,6 @@ class A(Document):
355373
try:
356374
_set_session(session)
357375
yield
376+
_commit_with_retry(session)
358377
finally:
359378
_clear_session()

tests/test_connection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -658,9 +658,10 @@ def test_multiple_connection_settings(self):
658658
mongo_connections["t2"].server_info()
659659

660660
assert mongo_connections["t1"].address[0] == "localhost"
661-
assert (
662-
mongo_connections["t2"].address[0] == "localhost"
663-
) # weird but we have this with replicaset
661+
assert mongo_connections["t2"].address[0] in (
662+
"localhost",
663+
"127.0.0.1",
664+
) # weird but there is a discrepancy in the address in replicaset setup
664665
assert mongo_connections["t1"].read_preference == ReadPreference.PRIMARY
665666
assert (
666667
mongo_connections["t2"].read_preference == ReadPreference.PRIMARY_PREFERRED

tests/test_context_managers.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import random
23
import time
34
import unittest
@@ -28,7 +29,7 @@
2829

2930

3031
class TestRollbackError(Exception):
31-
pass
32+
__test__ = False # Silence pytest warning
3233

3334

3435
class TestableThread(Thread):
@@ -38,6 +39,8 @@ class TestableThread(Thread):
3839
REF: https://gist.github.com/sbrugman/59b3535ebcd5aa0e2598293cfa58b6ab
3940
"""
4041

42+
__test__ = False # Silence pytest warning
43+
4144
def __init__(self, *args, **kwargs):
4245
super().__init__(*args, **kwargs)
4346
self.exc = None
@@ -753,14 +756,15 @@ def run_tx():
753756

754757
except TestRollbackError:
755758
pass
756-
# except OperationError as op_failure:
757-
# """
758-
# See thread safety test below for more details about TransientTransactionError handling
759-
# """
760-
# if "TransientTransactionError" in str(op_failure):
761-
# run_tx()
762-
# else:
763-
# raise op_failure
759+
except OperationError as op_failure:
760+
"""
761+
See thread safety test below for more details about TransientTransactionError handling
762+
"""
763+
if "TransientTransactionError" in str(op_failure):
764+
logging.warning("TransientTransactionError - retrying...")
765+
run_tx()
766+
else:
767+
raise op_failure
764768

765769
run_tx()
766770
assert "a" == A.objects.get(id=a_doc.id).name
@@ -789,10 +793,10 @@ def test_thread_safety_of_transactions(self):
789793
case, then no amount of runtime variability should have
790794
an effect on the output.
791795
792-
This test sets up 10 records, each with an integer field
796+
This test sets up e.g 10 records, each with an integer field
793797
of value 0 - 9.
794798
795-
We then spin up 10 threads and attempt to update a target
799+
We then spin up e.g 10 threads and attempt to update a target
796800
record by multiplying its integer value by 10. Then, if
797801
the target record is even, throw an exception, which
798802
should then roll the transaction back. The odd rows always
@@ -807,24 +811,26 @@ def test_thread_safety_of_transactions(self):
807811
connect("mongoenginetest")
808812

809813
class A(Document):
810-
i = IntField()
814+
i = IntField(unique=True)
811815

812816
A.drop_collection()
813817
# Ensure the collection is created
814-
A.objects.create(i=0)
818+
_ = A.objects.first()
819+
820+
thread_count = 20
815821

816822
def thread_fn(idx):
817823
# Open the transaction at some unknown interval
818824
time.sleep(random.uniform(0.1, 0.5))
819825
try:
820826
with run_in_transaction():
821827
a = A.objects.get(i=idx)
822-
a.i = idx * 10
828+
a.i = idx * thread_count
823829
# Save at some unknown interval
824830
time.sleep(random.uniform(0.1, 0.5))
825831
a.save()
826832

827-
# Force roll backs for the even runs...
833+
# Force rollbacks for the even runs...
828834
if idx % 2 == 0:
829835
raise TestRollbackError()
830836
except TestRollbackError:
@@ -841,6 +847,7 @@ def thread_fn(idx):
841847
"""
842848
error_labels = op_failure.details.get("errorLabels", [])
843849
if "TransientTransactionError" in error_labels:
850+
logging.warning("TransientTransactionError - retrying...")
844851
thread_fn(idx)
845852
else:
846853
raise op_failure
@@ -854,7 +861,6 @@ def thread_fn(idx):
854861
A.objects.all().delete()
855862

856863
# Prepopulate the data for reads
857-
thread_count = 20
858864
for i in range(thread_count):
859865
A.objects.create(i=i)
860866

@@ -868,13 +874,10 @@ def thread_fn(idx):
868874
t.join()
869875

870876
# Check the sum
871-
expected_sum = 0
872-
for i in range(thread_count):
873-
if i % 2 == 0:
874-
expected_sum += i
875-
else:
876-
expected_sum += i * 10
877-
assert expected_sum == 1090
877+
expected_sum = sum(
878+
i if i % 2 == 0 else i * thread_count for i in range(thread_count)
879+
)
880+
assert expected_sum == 2090
878881
assert expected_sum == sum(a.i for a in A.objects.all())
879882

880883

0 commit comments

Comments
 (0)