Skip to content

Work around Julia's Base.Sort.MissingOptimization bugs #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/SortingAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ struct TimSortAlg <: Algorithm end
struct RadixSortAlg <: Algorithm end
struct CombSortAlg <: Algorithm end

function maybe_optimize(x::Algorithm)
function maybe_optimize(x::Algorithm)
isdefined(Base.Sort, :InitialOptimizations) ? Base.Sort.InitialOptimizations(x) : x
end
end
const HeapSort = maybe_optimize(HeapSortAlg())
const TimSort = maybe_optimize(TimSortAlg())
# Whenever InitialOptimizations is defined, RadixSort falls
# Whenever InitialOptimizations is defined, RadixSort falls
# back to Base.DEFAULT_STABLE which already includes them.
const RadixSort = RadixSortAlg()

Expand Down Expand Up @@ -79,6 +79,27 @@ end
#
# Original author: @kmsquire

# Compatibility shims that are compiled in only when VERSION lies in the closed
# range 1.9.0-alpha through 1.9.1; every other version takes the `else` branch.
@static if v"1.9.0-alpha" <= VERSION <= v"1.9.1"
# Range indexing for Base's internal `WithoutMissingVector` wrapper: allocate a
# fresh `Vector` with element type `eltype(v)` and length `length(i)`, fill it
# from the matching slice of the wrapped `v.data`, and return that new vector.
function Base.getindex(v::Base.Sort.WithoutMissingVector, i::UnitRange)
out = Vector{eltype(v)}(undef, length(i))
# Broadcast-assign the slice into `out`, so the stored elements take on the
# element type chosen above.
out .= v.data[i]
out
end

# skip MissingOptimization due to JuliaLang/julia#50171
# NOTE(review): `.next` is presumably the stage wrapped by the outermost layer of
# `Base.DEFAULT_STABLE` (i.e. everything after MissingOptimization) -- confirm
# against Base.Sort for the targeted versions.
const _FIVE_ARG_SAFE_DEFAULT_STABLE = Base.DEFAULT_STABLE.next

# Explicitly define conversion from _sort!(v, alg, order, kw) to sort!(v, lo, hi, alg, order)
# To avoid excessively strict dispatch loop detection
# Applies to the four algorithm types named in the `Union` below.
function Base.Sort._sort!(v::AbstractVector, a::Union{HeapSortAlg, TimSortAlg, RadixSortAlg, CombSortAlg}, o::Base.Order.Ordering, kw)
# Makes `lo`, `hi` and `scratch` available as locals (presumably read from `kw`).
Base.Sort.@getkw lo hi scratch
# Delegate to the five-argument `sort!` method for algorithm `a`.
sort!(v, lo, hi, a, o)
# Return `scratch` as the result of this method.
scratch
end
else
# Outside the affected version range, use `Base.DEFAULT_STABLE` unchanged.
const _FIVE_ARG_SAFE_DEFAULT_STABLE = Base.DEFAULT_STABLE
end

const Run = UnitRange{Int}

const MIN_GALLOP = 7
Expand Down Expand Up @@ -490,7 +511,7 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, ::TimSortAlg, o::Ordering)
# Make a run of length minrun
count = min(minrun, hi-i+1)
run_range = i:i+count-1
sort!(v, i, i+count-1, DEFAULT_STABLE, o)
sort!(v, i, i+count-1, _FIVE_ARG_SAFE_DEFAULT_STABLE, o)
else
if !issorted(run_range)
run_range = last(run_range):first(run_range)
Expand Down
29 changes: 25 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@ using StatsBase
using Random

a = rand(1:10000, 1000)
am = [rand() < .9 ? i : missing for i in a]

for alg in [TimSort, HeapSort, RadixSort, CombSort]
for alg in [TimSort, HeapSort, RadixSort, CombSort, SortingAlgorithms.TimSortAlg()]
b = sort(a, alg=alg)
@test issorted(b)
ix = sortperm(a, alg=alg)
b = a[ix]
@test issorted(b)
@test a[ix] == b

# legacy 3-argument calling convention
@test b == sort!(copy(a), alg, Base.Order.Forward)

b = sort(a, alg=alg, rev=true)
@test issorted(b, rev=true)
ix = sortperm(a, alg=alg, rev=true)
Expand All @@ -34,9 +38,26 @@ for alg in [TimSort, HeapSort, RadixSort, CombSort]
invpermute!(c, ix)
@test c == a

if alg != RadixSort # RadixSort does not work with Lt orderings
if alg != RadixSort # RadixSort does not work with Lt orderings or missing
c = sort(a, alg=alg, lt=(>))
@test b == c

# Issue https://github.com/JuliaData/DataFrames.jl/issues/3340
bm1 = sort(am, alg=alg)
@test issorted(bm1)
@test count(ismissing, bm1) == count(ismissing, am)

bm2 = am[sortperm(am, alg=alg)]
@test issorted(bm2)
@test count(ismissing, bm2) == count(ismissing, am)

bm3 = am[sortperm!(collect(eachindex(am)), am, alg=alg)]
@test issorted(bm3)
@test count(ismissing, bm3) == count(ismissing, am)

if alg == TimSort # Stable
@test all(bm1 .=== bm2 .=== bm3)
end
end

c = sort(a, alg=alg, by=x->1/x)
Expand Down Expand Up @@ -103,8 +124,8 @@ for n in [0:10..., 100, 101, 1000, 1001]
# test float sorting with NaNs
s = sort(v, alg=alg, order=ord)
@test issorted(s, order=ord)
# This tests that NaNs (which compare equivalent) are treated stably

# This tests that NaNs (which compare equivalent) are treated stably
# even when the underlying algorithm is unstable. That it happens to
# pass is not a part of the public API:
@test reinterpret(UInt64, v[map(isnan, v)]) == reinterpret(UInt64, s[map(isnan, s)])
Expand Down