diff --git a/.appveyor.yml b/.appveyor.yml index e273845..cb8a7b7 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,7 +1,7 @@ environment: matrix: - julia_version: 1.0 - - julia_version: 1.4 + - julia_version: 1 - julia_version: nightly platform: diff --git a/.travis.yml b/.travis.yml index c16d287..8551e50 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,7 @@ os: - osx julia: - 1.0 - - 1.4 + - 1 - nightly env: - JULIA_NUM_THREADS=1 diff --git a/Project.toml b/Project.toml index 834f958..855265e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "LRUCache" uuid = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" -version = "1.1.0" +version = "1.2.0" [compat] julia = "1" diff --git a/src/LRUCache.jl b/src/LRUCache.jl index 6d17a93..45e085e 100644 --- a/src/LRUCache.jl +++ b/src/LRUCache.jl @@ -63,14 +63,15 @@ function Base.get(lru::LRU, key, default) end end function Base.get(default::Callable, lru::LRU, key) - lock(lru.lock) do - if _unsafe_haskey(lru, key) - v = _unsafe_getindex(lru, key) - return v - else - return default() - end + lock(lru.lock) + if _unsafe_haskey(lru, key) + v = _unsafe_getindex(lru, key) + unlock(lru.lock) + return v + else + unlock(lru.lock) end + return default() end function Base.get!(lru::LRU, key, default) lock(lru.lock) do @@ -85,16 +86,25 @@ function Base.get!(lru::LRU, key, default) end end function Base.get!(default::Callable, lru::LRU, key) - lock(lru.lock) do - if _unsafe_haskey(lru, key) - v = _unsafe_getindex(lru, key) - return v - end - v = default() + lock(lru.lock) + if _unsafe_haskey(lru, key) + v = _unsafe_getindex(lru, key) + unlock(lru.lock) + return v + else + unlock(lru.lock) + end + v = default() + lock(lru.lock) + if _unsafe_haskey(lru, key) + # should we test that this yields the same result as default() + v = _unsafe_getindex(lru, key) + else _unsafe_addindex!(lru, v, key) _unsafe_resize!(lru) - return v end + unlock(lru.lock) + return v end function _unsafe_getindex(lru::LRU, key) diff --git a/test/runtests.jl b/test/runtests.jl index 4da663f..a363602 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,7 +12,7 @@ using Base.Threads end @test collect(cache) == collect(i=>i for i in r) - @threads for i = 1:10:100 + for i = 1:10:100 @test haskey(cache, i) @test !haskey(cache, 100+i) end @@ -137,4 +137,20 @@ end @test_throws KeyError getindex(cache, p10[1]) end +@testset "Recursive lock in get(!)" begin + cache = LRU{Int,Int}(; maxsize = 100) + p = randperm(100) + cache[1] = 1 + + f!(cache, i) = get!(()->(f!(cache, i-1) + 1), cache, i) + @threads for i = 1:100 + f!(cache, p[i]) + end + + @threads for i = 1:100 + @test haskey(cache, i) + @test cache[i] == i + end +end + include("originaltests.jl")