Skip to content

Commit f503808

Browse files
committed
improve concurrency safety for Compiler.finish!
Similar to #57229, this commit ensures that `Compiler.finish!` properly synchronizes the operations to set `max_world` for cached `CodeInstance`s by holding the world counter lock. Previously, `Compiler.finish!` relied on a narrow timing window to avoid race conditions, which was not a robust approach in a concurrent execution environment. This change ensures that `Compiler.finish!` holds the appropriate lock (via `jl_promote_ci_to_current`).
1 parent b65f004 commit f503808

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

Compiler/src/typeinfer.jl

+20-9
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ If set to `true`, record per-method-instance timings within type inference in th
9292
__set_measure_typeinf(onoff::Bool) = __measure_typeinf__[] = onoff
9393
const __measure_typeinf__ = RefValue{Bool}(false)
9494

95-
function finish!(interp::AbstractInterpreter, caller::InferenceState)
95+
function finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt)
9696
result = caller.result
9797
opt = result.src
9898
if opt isa OptimizationState
@@ -108,12 +108,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
108108
ci = result.ci
109109
# if we aren't cached, we don't need this edge
110110
# but our caller might, so let's just make it anyways
111-
if last(result.valid_worlds) >= get_world_counter()
112-
# TODO: this should probably come after all store_backedges (after optimizations) for the entire graph in finish_cycle
113-
# since we should be requiring that all edges first get their backedges set, as a batch
114-
result.valid_worlds = WorldRange(first(result.valid_worlds), typemax(UInt))
115-
end
116-
if last(result.valid_worlds) == typemax(UInt)
111+
if last(result.valid_worlds) >= validation_world
117112
# if we can record all of the backedges in the global reverse-cache,
118113
# we can now widen our applicability in the global cache too
119114
store_backedges(ci, edges)
@@ -202,7 +197,14 @@ function finish_nocycle(::AbstractInterpreter, frame::InferenceState)
202197
if opt isa OptimizationState # implies `may_optimize(caller.interp) === true`
203198
optimize(frame.interp, opt, frame.result)
204199
end
205-
finish!(frame.interp, frame)
200+
validation_world = get_world_counter()
201+
finish!(frame.interp, frame, validation_world)
202+
if isdefined(frame.result, :ci)
203+
# After validation, under the world_counter_lock, set max_world to typemax(UInt) for all dependencies
204+
# (recursively). From that point onward the ordinary backedge mechanism is responsible for maintaining
205+
# validity.
206+
ccall(:jl_promote_ci_to_current, Cvoid, (Any, UInt), frame.result.ci, validation_world)
207+
end
206208
if frame.cycleid != 0
207209
frames = frame.callstack::Vector{AbsIntState}
208210
@assert frames[end] === frame
@@ -236,10 +238,19 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
236238
optimize(caller.interp, opt, caller.result)
237239
end
238240
end
241+
validation_world = get_world_counter()
242+
cis = CodeInstance[]
239243
for frameid = cycleid:length(frames)
240244
caller = frames[frameid]::InferenceState
241-
finish!(caller.interp, caller)
245+
finish!(caller.interp, caller, validation_world)
246+
if isdefined(caller.result, :ci)
247+
push!(cis, caller.result.ci)
248+
end
242249
end
250+
# After validation, under the world_counter_lock, set max_world to typemax(UInt) for all dependencies
251+
# (recursively). From that point onward the ordinary backedge mechanism is responsible for maintaining
252+
# validity.
253+
ccall(:jl_promote_cis_to_current, Cvoid, (Ptr{CodeInstance}, Csize_t, UInt), cis, length(cis), validation_world)
243254
resize!(frames, cycleid - 1)
244255
return nothing
245256
end

src/staticdata.c

+16
Original file line numberDiff line numberDiff line change
@@ -4408,6 +4408,22 @@ JL_DLLEXPORT void jl_promote_ci_to_current(jl_code_instance_t *ci, size_t valida
44084408
JL_UNLOCK(&world_counter_lock);
44094409
}
44104410

4411+
JL_DLLEXPORT void jl_promote_cis_to_current(jl_code_instance_t **cis, size_t n, size_t validated_world)
4412+
{
4413+
size_t current_world = jl_atomic_load_relaxed(&jl_world_counter);
4414+
// No need to acquire the lock if we've been invalidated anyway
4415+
if (current_world > validated_world)
4416+
return;
4417+
JL_LOCK(&world_counter_lock);
4418+
current_world = jl_atomic_load_relaxed(&jl_world_counter);
4419+
if (current_world == validated_world) {
4420+
for (size_t i = 0; i < n; i++) {
4421+
_jl_promote_ci_to_current(cis[i], validated_world);
4422+
}
4423+
}
4424+
JL_UNLOCK(&world_counter_lock);
4425+
}
4426+
44114427
#ifdef __cplusplus
44124428
}
44134429
#endif

0 commit comments

Comments
 (0)