1
+ import logging
1
2
import random
2
3
import time
3
4
import unittest
28
29
29
30
30
31
class TestRollbackError (Exception ):
31
- pass
32
+ __test__ = False # Silence pytest warning
32
33
33
34
34
35
class TestableThread (Thread ):
@@ -38,6 +39,8 @@ class TestableThread(Thread):
38
39
REF: https://gist.github.com/sbrugman/59b3535ebcd5aa0e2598293cfa58b6ab
39
40
"""
40
41
42
+ __test__ = False # Silence pytest warning
43
+
41
44
def __init__ (self , * args , ** kwargs ):
42
45
super ().__init__ (* args , ** kwargs )
43
46
self .exc = None
@@ -753,14 +756,15 @@ def run_tx():
753
756
754
757
except TestRollbackError :
755
758
pass
756
- # except OperationError as op_failure:
757
- # """
758
- # See thread safety test below for more details about TransientTransactionError handling
759
- # """
760
- # if "TransientTransactionError" in str(op_failure):
761
- # run_tx()
762
- # else:
763
- # raise op_failure
759
+ except OperationError as op_failure :
760
+ """
761
+ See thread safety test below for more details about TransientTransactionError handling
762
+ """
763
+ if "TransientTransactionError" in str (op_failure ):
764
+ logging .warning ("TransientTransactionError - retrying..." )
765
+ run_tx ()
766
+ else :
767
+ raise op_failure
764
768
765
769
run_tx ()
766
770
assert "a" == A .objects .get (id = a_doc .id ).name
@@ -789,10 +793,10 @@ def test_thread_safety_of_transactions(self):
789
793
case, then no amount of runtime variability should have
790
794
an effect on the output.
791
795
792
- This test sets up 10 records, each with an integer field
796
+ This test sets up e.g 10 records, each with an integer field
793
797
of value 0 - 9.
794
798
795
- We then spin up 10 threads and attempt to update a target
799
+ We then spin up e.g 10 threads and attempt to update a target
796
800
record by multiplying its integer value by 10. Then, if
797
801
the target record is even, throw an exception, which
798
802
should then roll the transaction back. The odd rows always
@@ -807,24 +811,26 @@ def test_thread_safety_of_transactions(self):
807
811
connect ("mongoenginetest" )
808
812
809
813
class A (Document ):
810
- i = IntField ()
814
+ i = IntField (unique = True )
811
815
812
816
A .drop_collection ()
813
817
# Ensure the collection is created
814
- A .objects .create (i = 0 )
818
+ _ = A .objects .first ()
819
+
820
+ thread_count = 20
815
821
816
822
def thread_fn (idx ):
817
823
# Open the transaction at some unknown interval
818
824
time .sleep (random .uniform (0.1 , 0.5 ))
819
825
try :
820
826
with run_in_transaction ():
821
827
a = A .objects .get (i = idx )
822
- a .i = idx * 10
828
+ a .i = idx * thread_count
823
829
# Save at some unknown interval
824
830
time .sleep (random .uniform (0.1 , 0.5 ))
825
831
a .save ()
826
832
827
- # Force roll backs for the even runs...
833
+ # Force rollbacks for the even runs...
828
834
if idx % 2 == 0 :
829
835
raise TestRollbackError ()
830
836
except TestRollbackError :
@@ -841,6 +847,7 @@ def thread_fn(idx):
841
847
"""
842
848
error_labels = op_failure .details .get ("errorLabels" , [])
843
849
if "TransientTransactionError" in error_labels :
850
+ logging .warning ("TransientTransactionError - retrying..." )
844
851
thread_fn (idx )
845
852
else :
846
853
raise op_failure
@@ -854,7 +861,6 @@ def thread_fn(idx):
854
861
A .objects .all ().delete ()
855
862
856
863
# Prepopulate the data for reads
857
- thread_count = 20
858
864
for i in range (thread_count ):
859
865
A .objects .create (i = i )
860
866
@@ -868,13 +874,10 @@ def thread_fn(idx):
868
874
t .join ()
869
875
870
876
# Check the sum
871
- expected_sum = 0
872
- for i in range (thread_count ):
873
- if i % 2 == 0 :
874
- expected_sum += i
875
- else :
876
- expected_sum += i * 10
877
- assert expected_sum == 1090
877
+ expected_sum = sum (
878
+ i if i % 2 == 0 else i * thread_count for i in range (thread_count )
879
+ )
880
+ assert expected_sum == 2090
878
881
assert expected_sum == sum (a .i for a in A .objects .all ())
879
882
880
883
0 commit comments