@@ -634,6 +634,7 @@ def m1(i: int):
634
634
635
635
def run_thread (cache , nthreads ):
636
636
637
+ # direct tests
637
638
cache .clear ()
638
639
cache .reset ()
639
640
assert len (cache ) == 0
@@ -651,6 +652,7 @@ def banged(s: str, n: int) -> str:
651
652
assert banged ("a" , 3 ) == "aaa!" # hit
652
653
assert cache .hits () == 0.5
653
654
655
+ # threaded tests
654
656
cache .clear ()
655
657
cache .reset ()
656
658
assert len (cache ) == 0
@@ -660,6 +662,8 @@ def banged(s: str, n: int) -> str:
660
662
LS , LI = ["a" , "b" , "c" , "d" ], list (range (4 ))
661
663
662
664
def run ():
665
+ name = threading .current_thread ().name
666
+ log .debug (f"thread start: { name } " )
663
667
ls , li = LS .copy (), LI .copy ()
664
668
random .shuffle (ls )
665
669
random .shuffle (li )
@@ -669,17 +673,23 @@ def run():
669
673
barrier .wait ()
670
674
assert banged (s , n ) == s * n + "!"
671
675
barrier .wait ()
676
+ log .debug (f"thread end: { name } " )
672
677
673
678
threads = [ threading .Thread (target = run , name = f"thread { i } " ) for i in range (nthreads ) ]
674
679
list (map (lambda t : t .start (), threads ))
675
680
list (map (lambda t : t .join (), threads ))
676
681
682
+ # log.info(f"nthreads={nthreads} stats={cache.stats()}")
683
+
677
684
assert len (cache ) == 32
678
- assert abs (cache .hits () - (1.0 / (nthreads + 1 ))) < 0.001
685
+ # FIXME the hit ratio may not be deterministic?
686
+ # 16 * 2 gets-no-hit + 16 * (nthreads - 1) get-with-hit
687
+ hits = (nthreads - 1 ) / (nthreads + 1 )
688
+ assert cache .hits () == hits
679
689
680
690
def test_threads ():
681
- cache = ctu .StatsCache (ctu .DictCache ())
682
- run_thread (cache , 2 )
691
+ cache = ctu .LockedCache ( ctu . StatsCache (ctu .DictCache ()), threading . RLock ())
692
+ run_thread (cache , 4 )
683
693
del cache
684
694
685
695
def test_nogil ():
0 commit comments