Skip to content

Commit d766aee

Browse files
Merge pull request #71 from YingboMa/myb/fallback
Add generic fallback to all scalar functions
2 parents b351b97 + cf61ad0 commit d766aee

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/NaNMath.jl

+15-2
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,20 @@ for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
99
($f)(x::Float64) = ccall(($(string(f)),libm), Float64, (Float64,), x)
1010
($f)(x::Float32) = ccall(($(string(f,"f")),libm), Float32, (Float32,), x)
1111
($f)(x::Real) = ($f)(float(x))
12+
if $f !== :lgamma
13+
($f)(x) = (Base.$f)(x)
14+
end
1215
end
1316
end
1417

18+
for f in (:sqrt,)
19+
@eval ($f)(x) = (Base.$f)(x)
20+
end
21+
22+
for f in (:max, :min)
23+
@eval ($f)(x, y) = (Base.$f)(x, y)
24+
end
25+
1526
# Would be more efficient to remove the domain check in Base.sqrt(),
1627
# but this doesn't seem easy to do.
1728
sqrt(x::T) where {T<:AbstractFloat} = x < 0.0 ? T(NaN) : Base.sqrt(x)
@@ -22,11 +33,13 @@ pow(x::Float64, y::Float64) = ccall((:pow,libm), Float64, (Float64,Float64), x,
2233
pow(x::Float32, y::Float32) = ccall((:powf,libm), Float32, (Float32,Float32), x, y)
2334
# We `promote` first before converting to floating pointing numbers to ensure that
2435
# e.g. `pow(::Float32, ::Int)` ends up calling `pow(::Float32, ::Float32)`
25-
pow(x::Number, y::Number) = pow(promote(x, y)...)
26-
pow(x::T, y::T) where {T<:Number} = pow(float(x), float(y))
36+
pow(x::Real, y::Real) = pow(promote(x, y)...)
37+
pow(x::T, y::T) where {T<:Real} = pow(float(x), float(y))
38+
pow(x, y) = ^(x, y)
2739

2840
# The following combinations are safe, so we can fall back to ^
2941
pow(x::Number, y::Integer) = x^y
42+
pow(x::Real, y::Integer) = x^y
3043
pow(x::Complex, y::Complex) = x^y
3144

3245
"""

test/runtests.jl

+15
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,18 @@ using Test
8383
@test isnan(NaNMath.max(NaN, NaN))
8484
@test isnan(NaNMath.max(NaN))
8585
@test NaNMath.max(NaN, NaN, 0.0, 1.0) == 1.0
86+
87+
# Test forwarding
88+
x = 1 + 2im
89+
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
90+
:log1p, :sqrt)
91+
@test @eval (NaNMath.$f)(x) == $f(x)
92+
end
93+
94+
struct A end
95+
Base.isless(::A, ::A) = false
96+
y = A()
97+
for f in (:max, :min)
98+
@test @eval (NaNMath.$f)(y, y) == $f(y, y)
99+
end
100+
@test NaNMath.pow(x, x) == ^(x, x)

0 commit comments

Comments
 (0)