Skip to content

Commit ba251e8

Browse files
authored
Fix sorting bugs (esp MissingOptimization) that come up when using SortingAlgorithms.TimSort (#50171)
1 parent c5b0a6c commit ba251e8

File tree

2 files changed

+54
-14
lines changed

2 files changed

+54
-14
lines changed

base/sort.jl

+14-14
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export # not exported by Base
4444
SMALL_ALGORITHM,
4545
SMALL_THRESHOLD
4646

47+
abstract type Algorithm end
4748

4849
## functions requiring only ordering ##
4950

@@ -436,7 +437,7 @@ for (sym, exp, type) in [
436437
(:mn, :(throw(ArgumentError("mn is needed but has not been computed"))), :(eltype(v))),
437438
(:mx, :(throw(ArgumentError("mx is needed but has not been computed"))), :(eltype(v))),
438439
(:scratch, nothing, :(Union{Nothing, Vector})), # could have different eltype
439-
(:allow_legacy_dispatch, true, Bool)]
440+
(:legacy_dispatch_entry, nothing, Union{Nothing, Algorithm})]
440441
usym = Symbol(:_, sym)
441442
@eval function $usym(v, o, kw)
442443
# using missing instead of nothing because scratch could === nothing.
@@ -499,8 +500,6 @@ internal or recursive calls.
499500
"""
500501
function _sort! end
501502

502-
abstract type Algorithm end
503-
504503

505504
"""
506505
MissingOptimization(next) <: Algorithm
@@ -524,12 +523,12 @@ struct WithoutMissingVector{T, U} <: AbstractVector{T}
524523
new{nonmissingtype(eltype(data)), typeof(data)}(data)
525524
end
526525
end
527-
Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i)
526+
Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i::Integer)
528527
out = v.data[i]
529528
@assert !(out isa Missing)
530529
out::eltype(v)
531530
end
532-
Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i)
531+
Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i::Integer)
533532
v.data[i] = x
534533
v
535534
end
@@ -590,8 +589,9 @@ function _sort!(v::AbstractVector, a::MissingOptimization, o::Ordering, kw)
590589
# we can assume v is equal to eachindex(o.data) which allows a copying partition
591590
# without allocations.
592591
lo_i, hi_i = lo, hi
593-
for i in eachindex(o.data) # equal to copy(v)
594-
x = o.data[i]
592+
cv = eachindex(o.data) # equal to copy(v)
593+
for i in lo:hi
594+
x = o.data[cv[i]]
595595
if ismissing(x) == (o.order == Reverse) # should x go at the beginning/end?
596596
v[lo_i] = i
597597
lo_i += 1
@@ -2149,25 +2149,25 @@ end
21492149
# Support 3-, 5-, and 6-argument versions of sort! for calling into the internals in the old way
21502150
sort!(v::AbstractVector, a::Algorithm, o::Ordering) = sort!(v, firstindex(v), lastindex(v), a, o)
21512151
function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering)
2152-
_sort!(v, a, o, (; lo, hi, allow_legacy_dispatch=false))
2152+
_sort!(v, a, o, (; lo, hi, legacy_dispatch_entry=a))
21532153
v
21542154
end
21552155
sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering, _) = sort!(v, lo, hi, a, o)
21562156
function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering, scratch::Vector)
2157-
_sort!(v, a, o, (; lo, hi, scratch, allow_legacy_dispatch=false))
2157+
_sort!(v, a, o, (; lo, hi, scratch, legacy_dispatch_entry=a))
21582158
v
21592159
end
21602160

21612161
# Support dispatch on custom algorithms in the old way
21622162
# sort!(::AbstractVector, ::Integer, ::Integer, ::MyCustomAlgorithm, ::Ordering) = ...
21632163
function _sort!(v::AbstractVector, a::Algorithm, o::Ordering, kw)
2164-
@getkw lo hi scratch allow_legacy_dispatch
2165-
if allow_legacy_dispatch
2164+
@getkw lo hi scratch legacy_dispatch_entry
2165+
if legacy_dispatch_entry === a
2166+
# This error prevents infinite recursion for unknown algorithms
2167+
throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o)), ::Any) is not defined"))
2168+
else
21662169
sort!(v, lo, hi, a, o)
21672170
scratch
2168-
else
2169-
# This error prevents infinite recursion for unknown algorithms
2170-
throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o))) is not defined"))
21712171
end
21722172
end
21732173

test/sorting.jl

+40
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,46 @@ Base.similar(A::MyArray49392, ::Type{T}, dims::Dims{N}) where {T, N} = MyArray49
10251025
@test all(sort!(y, dims=2) .== sort!(x,dims=2))
10261026
end
10271027

1028+
@testset "MissingOptimization fastpath for Perm ordering when lo:hi ≠ eachindex(v)" begin
1029+
v = [rand() < .5 ? missing : rand() for _ in 1:100]
1030+
ix = collect(1:100)
1031+
sort!(ix, 1, 10, Base.Sort.DEFAULT_STABLE, Base.Order.Perm(Base.Order.Forward, v))
1032+
@test issorted(v[ix[1:10]])
1033+
end
1034+
1035+
struct NonScalarIndexingOfWithoutMissingVectorAlg <: Base.Sort.Algorithm end
1036+
function Base.Sort._sort!(v::AbstractVector, ::NonScalarIndexingOfWithoutMissingVectorAlg, o::Base.Order.Ordering, kw)
1037+
Base.Sort.@getkw lo hi
1038+
first_half = v[lo:lo+(hi-lo)÷2]
1039+
second_half = v[lo+(hi-lo)÷2+1:hi]
1040+
whole = v[lo:hi]
1041+
all(vcat(first_half, second_half) .=== whole) || error()
1042+
out = Base.Sort._sort!(whole, Base.Sort.DEFAULT_STABLE, o, (;kw..., lo=1, hi=length(whole)))
1043+
v[lo:hi] .= whole
1044+
out
1045+
end
1046+
1047+
@testset "Non-scaler indexing of WithoutMissingVector" begin
1048+
@testset "Unit test" begin
1049+
wmv = Base.Sort.WithoutMissingVector(Union{Missing, Int}[1, 7, 2, 9])
1050+
@test wmv[[1, 3]] == [1, 2]
1051+
@test wmv[1:3] == [1, 7, 2]
1052+
end
1053+
@testset "End to end" begin
1054+
alg = Base.Sort.InitialOptimizations(NonScalarIndexingOfWithoutMissingVectorAlg())
1055+
@test issorted(sort(rand(100); alg))
1056+
@test issorted(sort([rand() < .5 ? missing : randstring() for _ in 1:100]; alg))
1057+
end
1058+
end
1059+
1060+
struct DispatchLoopTestAlg <: Base.Sort.Algorithm end
1061+
function Base.sort!(v::AbstractVector, lo::Integer, hi::Integer, ::DispatchLoopTestAlg, order::Base.Order.Ordering)
1062+
sort!(view(v, lo:hi); order)
1063+
end
1064+
@testset "Support dispatch from the old style to the new style and back" begin
1065+
@test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward))
1066+
end
1067+
10281068
# This testset is at the end of the file because it is slow.
10291069
@testset "searchsorted" begin
10301070
numTypes = [ Int8, Int16, Int32, Int64, Int128,

0 commit comments

Comments
 (0)