diff --git a/src/NaNMath.jl b/src/NaNMath.jl index e846f83..e9f0361 100644 --- a/src/NaNMath.jl +++ b/src/NaNMath.jl @@ -14,7 +14,8 @@ end # Would be more efficient to remove the domain check in Base.sqrt(), # but this doesn't seem easy to do. -sqrt(x::Real) = x < 0.0 ? NaN : Base.sqrt(x) +sqrt(x::T) where {T<:AbstractFloat} = x < 0.0 ? T(NaN) : Base.sqrt(x) +sqrt(x::Real) = sqrt(float(x)) # Don't override built-in ^ operator pow(x::Float64, y::Float64) = ccall((:pow,libm), Float64, (Float64,Float64), x, y) diff --git a/test/runtests.jl b/test/runtests.jl index a67b601..6e34c36 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,15 @@ using Test @test NaNMath.pow(-1.5,2.3) isa Float64 @test isnan(NaNMath.sqrt(-5)) @test NaNMath.sqrt(5) == Base.sqrt(5) +@test isnan(NaNMath.sqrt(-3.2f0)) && NaNMath.sqrt(-3.2f0) isa Float32 +@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat +@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64 +@inferred NaNMath.sqrt(5) +@inferred NaNMath.sqrt(5.0) +@inferred NaNMath.sqrt(5.0f0) +@inferred NaNMath.sqrt(-5) +@inferred NaNMath.sqrt(-5.0) +@inferred NaNMath.sqrt(-5.0f0) @test NaNMath.sum([1., 2., NaN]) == 3.0 @test NaNMath.sum([1. 2.; NaN 1.]) == 4.0 @test isnan(NaNMath.sum([NaN, NaN]))