Skip to content

Commit f8ef8df

Browse files
KenoKristofferC
authored andcommitted
staticdata: Close data race after backedge insertion (#57229)
Addresses review comment in #57212 (comment). The key is that the hand-off of responsibility for verification between the loading code and the ordinary backedge mechanism happens under the world counter lock to ensure synchronization. (cherry picked from commit 34aceb5)
1 parent 7825364 commit f8ef8df

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

base/staticdata.jl

+16-12
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,16 @@ end
3737
function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method}, external::Bool=false)
3838
for i = 1:length(edges)
3939
codeinst = edges[i]::CodeInstance
40-
verify_method_graph(codeinst, stack, visiting, mwis)
40+
validation_world = get_world_counter()
41+
verify_method_graph(codeinst, stack, visiting, mwis, validation_world)
42+
# After validation, under the world_counter_lock, set max_world to typemax(UInt) for all dependencies
43+
# (recursively). From that point onward the ordinary backedge mechanism is responsible for maintaining
44+
# validity.
45+
@ccall jl_promote_ci_to_current(codeinst::Any, validation_world::UInt)::Cvoid
4146
minvalid = codeinst.min_world
4247
maxvalid = codeinst.max_world
48+
# Finally, if this CI is still valid in some world age and and belongs to an external method(specialization),
49+
# poke it that mi's cache
4350
if maxvalid minvalid && external
4451
caller = get_ci_mi(codeinst)
4552
@assert isdefined(codeinst, :inferred) # See #53586, #53109
@@ -55,9 +62,9 @@ function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visi
5562
end
5663
end
5764

58-
function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method})
65+
function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method}, validation_world::UInt)
5966
@assert isempty(stack); @assert isempty(visiting);
60-
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting, mwis)
67+
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting, mwis, validation_world)
6168
@assert child_cycle == 0
6269
@assert isempty(stack); @assert isempty(visiting);
6370
nothing
@@ -67,15 +74,14 @@ end
6774
# - Visit the entire call graph, starting from edges[idx] to determine if that method is valid
6875
# - Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
6976
# and slightly modified with an early termination option once the computation reaches its minimum
70-
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method})
77+
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, mwis::IdSet{Method}, validation_world::UInt)
7178
world = codeinst.min_world
7279
let max_valid2 = codeinst.max_world
7380
if max_valid2 WORLD_AGE_REVALIDATION_SENTINEL
7481
return 0, world, max_valid2
7582
end
7683
end
77-
current_world = get_world_counter()
78-
local minworld::UInt, maxworld::UInt = 1, current_world
84+
local minworld::UInt, maxworld::UInt = 1, validation_world
7985
def = get_ci_mi(codeinst).def
8086
@assert def isa Method
8187
if haskey(visiting, codeinst)
@@ -177,7 +183,7 @@ function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visi
177183
end
178184
callee = edge
179185
local min_valid2::UInt, max_valid2::UInt
180-
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting, mwis)
186+
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting, mwis, validation_world)
181187
if minworld < min_valid2
182188
minworld = min_valid2
183189
end
@@ -209,16 +215,14 @@ function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visi
209215
if maxworld 0
210216
@atomic :monotonic child.min_world = minworld
211217
end
212-
if maxworld == current_world
218+
@atomic :monotonic child.max_world = maxworld
219+
if maxworld == validation_world && validation_world == get_world_counter()
213220
Base.Compiler.store_backedges(child, child.edges)
214-
@atomic :monotonic child.max_world = typemax(UInt)
215-
else
216-
@atomic :monotonic child.max_world = maxworld
217221
end
218222
@assert visiting[child] == length(stack) + 1
219223
delete!(visiting, child)
220224
invalidations = _jl_debug_method_invalidation[]
221-
if invalidations !== nothing && maxworld < current_world
225+
if invalidations !== nothing && maxworld < validation_world
222226
push!(invalidations, child, "verify_methods", cause)
223227
end
224228
end

src/staticdata.c

+28
Original file line numberDiff line numberDiff line change
@@ -4380,6 +4380,34 @@ JL_DLLEXPORT jl_value_t *jl_restore_package_image_from_file(const char *fname, j
43804380
return mod;
43814381
}
43824382

4383+
JL_DLLEXPORT void _jl_promote_ci_to_current(jl_code_instance_t *ci, size_t validated_world) JL_NOTSAFEPOINT
4384+
{
4385+
if (jl_atomic_load_relaxed(&ci->max_world) != validated_world)
4386+
return;
4387+
jl_atomic_store_relaxed(&ci->max_world, ~(size_t)0);
4388+
jl_svec_t *edges = jl_atomic_load_relaxed(&ci->edges);
4389+
for (size_t i = 0; i < jl_svec_len(edges); i++) {
4390+
jl_value_t *edge = jl_svecref(edges, i);
4391+
if (!jl_is_code_instance(edge))
4392+
continue;
4393+
_jl_promote_ci_to_current(ci, validated_world);
4394+
}
4395+
}
4396+
4397+
JL_DLLEXPORT void jl_promote_ci_to_current(jl_code_instance_t *ci, size_t validated_world)
4398+
{
4399+
size_t current_world = jl_atomic_load_relaxed(&jl_world_counter);
4400+
// No need to acquire the lock if we've been invalidated anyway
4401+
if (current_world > validated_world)
4402+
return;
4403+
JL_LOCK(&world_counter_lock);
4404+
current_world = jl_atomic_load_relaxed(&jl_world_counter);
4405+
if (current_world == validated_world) {
4406+
_jl_promote_ci_to_current(ci, validated_world);
4407+
}
4408+
JL_UNLOCK(&world_counter_lock);
4409+
}
4410+
43834411
#ifdef __cplusplus
43844412
}
43854413
#endif

0 commit comments

Comments
 (0)