-
Notifications
You must be signed in to change notification settings - Fork 1
Float32 loggamma #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,21 +5,6 @@ | |
| # See: D. E. G. Hare, "Computing the principal branch of log-Gamma," | ||
| # J. Algorithms 25, pp. 221-236 (1997) | ||
|
|
||
| const HALF_LOG2PI_F64 = 9.1893853320467274178032927e-01 | ||
| const LOGPI_F64 = 1.1447298858494002 | ||
| const TWO_PI_F64 = 6.2831853071795864769252842 | ||
|
|
||
| # Lanczos-type rational approximation for loggamma on (2, 3) | ||
| # Used as the core for reduction-based approach | ||
| const _LOGGAMMA_P = ( | ||
| -2.44167345903529816830968e-01, 6.73523010531981020863696e-02, | ||
| -2.05808084277845478790009e-02, 7.38555102867398526627303e-03, | ||
| -2.89051033074153369901384e-03, 1.19275391170326097711398e-03, | ||
| -5.09669524743042422335582e-04, 2.23154759903498081132513e-04, | ||
| -9.94575127818085337147321e-05, 4.49262367382046739858373e-05, | ||
| -2.05077312586603517590604e-05 | ||
| ) | ||
|
|
||
| """ | ||
| loggamma(x::Real) | ||
|
|
||
|
|
@@ -32,15 +17,17 @@ but may differ from `log(gamma(x))` by an integer multiple of ``2\\pi i``. | |
| External links: [DLMF](https://dlmf.nist.gov/5.4), [Wikipedia](https://en.wikipedia.org/wiki/Gamma_function#The_log-gamma_function) | ||
| """ | ||
| loggamma(x::Float64) = _loggamma(x) | ||
| loggamma(x::Union{Float16, Float32}) = typeof(x)(_loggamma(Float64(x))) | ||
| loggamma(x::Float32) = _loggamma(x) | ||
| loggamma(x::Float16) = Float16(_loggamma(Float32(x))) | ||
| loggamma(x::Rational) = loggamma(float(x)) | ||
| loggamma(x::Integer) = loggamma(float(x)) | ||
| loggamma(z::Complex{Float64}) = _loggamma(z) | ||
| loggamma(z::Complex{Float32}) = Complex{Float32}(_loggamma(Complex{Float64}(z))) | ||
| loggamma(z::Complex{Float16}) = Complex{Float16}(_loggamma(Complex{Float64}(z))) | ||
| loggamma(z::Complex{Float32}) = _loggamma(z) | ||
| loggamma(z::Complex{Float16}) = Complex{Float16}(_loggamma(Complex{Float32}(z))) | ||
| loggamma(z::Complex{<:Integer}) = _loggamma(Complex{Float64}(z)) | ||
| loggamma(z::Complex{<:Rational}) = loggamma(float(z)) | ||
| function loggamma(x::BigFloat) | ||
| # For now we use the same implementation for BigFloat as Complex{BigFloat}. This is not ideal since it does more work than necessary. | ||
| if isnan(x) | ||
| return x | ||
| elseif isinf(x) | ||
|
|
@@ -68,20 +55,57 @@ logfactorial(x::Integer) = x < 0 ? throw(DomainError(x, "`x` must be non-negativ | |
|
|
||
| Returns a tuple `(log(abs(Γ(x))), sign(Γ(x)))` for real `x`. | ||
| """ | ||
| logabsgamma(x::Float32) = _logabsgamma(x) | ||
| logabsgamma(x::Real) = _logabsgamma(float(x)) | ||
| function logabsgamma(x::Float16) | ||
| y, s = _logabsgamma(Float32(x)) | ||
| return Float16(y), s | ||
| end | ||
|
|
||
|
|
||
| #################################### | ||
| ## Float64 loggamma implementation | ||
| #################################### | ||
|
|
||
| const HALF_LOG2PI_F64 = 9.1893853320467274178032927e-01 | ||
| const LOGPI_F64 = 1.1447298858494002 | ||
| const TWO_PI_F64 = 6.2831853071795864769252842 | ||
|
|
||
| # Taylor series coefficients for complex loggamma at z=1 and z=2 (Float64) | ||
| const _TAYLOR1_64 = ( | ||
| -5.7721566490153286060651188e-01, 8.2246703342411321823620794e-01, | ||
| -4.0068563438653142846657956e-01, 2.705808084277845478790009e-01, | ||
| -2.0738555102867398526627303e-01, 1.6955717699740818995241986e-01, | ||
| -1.4404989676884611811997107e-01, 1.2550966952474304242233559e-01, | ||
| -1.1133426586956469049087244e-01, 1.000994575127818085337147e-01, | ||
| -9.0954017145829042232609344e-02, 8.3353840546109004024886499e-02, | ||
| -7.6932516411352191472827157e-02, 7.1432946295361336059232779e-02, | ||
| -6.6668705882420468032903454e-02 | ||
| ) | ||
|
|
||
| const _TAYLOR2_64 = ( | ||
| 4.2278433509846713939348812e-01, 3.2246703342411321823620794e-01, | ||
| -6.7352301053198095133246196e-02, 2.0580808427784547879000897e-02, | ||
| -7.3855510286739852662729527e-03, 2.8905103307415232857531201e-03, | ||
| -1.1927539117032609771139825e-03, 5.0966952474304242233558822e-04, | ||
| -2.2315475845357937976132853e-04, 9.9457512781808533714662972e-05, | ||
| -4.4926236738133141700224489e-05, 2.0507212775670691553131246e-05 | ||
| ) | ||
|
|
||
| # Stirling asymptotic series for log(Γ(x)), valid for x > 0 sufficiently large | ||
| # coefficients are bernoulli[2k] / (2k*(2k-1)) for k = 1,...,8 | ||
|
|
||
| const _STIRLING_COEFFS_64 = ( | ||
| 8.333333333333333333333368e-02, -2.777777777777777777777778e-03, | ||
| 7.936507936507936507936508e-04, -5.952380952380952380952381e-04, | ||
| 8.417508417508417508417510e-04, -1.917526917526917526917527e-03, | ||
| 6.410256410256410256410257e-03, -2.955065359477124183006535e-02 | ||
| ) | ||
| function _loggamma_stirling(x::Float64) | ||
| t = inv(x) | ||
| w = t * t | ||
| return muladd(x - 0.5, log(x), -x + HALF_LOG2PI_F64 + # log(2π)/2 | ||
| t * @evalpoly(w, | ||
| 8.333333333333333333333368e-02, -2.777777777777777777777778e-03, | ||
| 7.936507936507936507936508e-04, -5.952380952380952380952381e-04, | ||
| 8.417508417508417508417510e-04, -1.917526917526917526917527e-03, | ||
| 6.410256410256410256410257e-03, -2.955065359477124183006535e-02 | ||
| ) | ||
| t * @evalpoly(w, _STIRLING_COEFFS_64...) | ||
| ) | ||
| end | ||
|
|
||
|
|
@@ -90,12 +114,7 @@ function _loggamma_asymptotic(z::Complex{Float64}) | |
| zinv = inv(z) | ||
| t = zinv * zinv | ||
| return (z - 0.5) * log(z) - z + HALF_LOG2PI_F64 + # log(2π)/2 | ||
| zinv * @evalpoly(t, | ||
| 8.333333333333333333333368e-02, -2.777777777777777777777778e-03, | ||
| 7.936507936507936507936508e-04, -5.952380952380952380952381e-04, | ||
| 8.417508417508417508417510e-04, -1.917526917526917526917527e-03, | ||
| 6.410256410256410256410257e-03, -2.955065359477124183006535e-02 | ||
| ) | ||
| zinv * @evalpoly(t, _STIRLING_COEFFS_64...) | ||
| end | ||
|
|
||
| function _logabsgamma(x::Float64) | ||
|
|
@@ -123,32 +142,6 @@ function _logabsgamma_unsafe_sub0(x::Float64) | |
| return LOGPI_F64 - log(abs(s)) - _loggamma(1.0 - x), sgn | ||
| end | ||
|
|
||
| function _logabsgamma(x::Float32) | ||
| y, s = _logabsgamma(Float64(x)) | ||
| return Float32(y), s | ||
| end | ||
|
|
||
| function _logabsgamma(x::Float16) | ||
| y, s = _logabsgamma(Float64(x)) | ||
| return Float16(y), s | ||
| end | ||
|
|
||
| function _logabsgamma(x::BigFloat) | ||
| if isnan(x) | ||
| return x, 1 | ||
| elseif isinf(x) | ||
| return x > 0 ? (x, 1) : (BigFloat(NaN), 1) | ||
| elseif x > 0 | ||
| return real(_loggamma_complex_bigfloat(Complex{BigFloat}(x, zero(BigFloat)))), 1 | ||
| elseif iszero(x) | ||
| return BigFloat(Inf), Int(sign(1 / x)) | ||
| end | ||
|
|
||
| s = sinpi(x) | ||
| s == 0 && return BigFloat(Inf), 1 | ||
| return real(_loggamma_complex_bigfloat(Complex{BigFloat}(x, zero(BigFloat)))), (signbit(s) ? -1 : 1) | ||
| end | ||
|
|
||
| # loggamma for real Float64 | ||
| function _loggamma(x::Float64) | ||
| if isnan(x) | ||
|
|
@@ -193,7 +186,10 @@ function _loggamma_unsafe_pos(x::Float64) | |
| end | ||
| end | ||
|
|
||
| # Complex loggamma for Float64 | ||
|
|
||
| #################################### | ||
| ## Complex{Float64} loggamma implementation | ||
| #################################### | ||
| # Combines the asymptotic series, Taylor series at z=1 and z=2, | ||
| # the reflection formula, and the shift recurrence. | ||
| function _loggamma(z::Complex{Float64}) | ||
|
|
@@ -221,28 +217,12 @@ function _loggamma(z::Complex{Float64}) | |
| # Taylor series at z=1 | ||
| # coefficients: [-γ; [(-1)^k * ζ(k)/k for k in 2:15]] | ||
| w = Complex(x - 1, y) | ||
| return w * @evalpoly(w, | ||
| -5.7721566490153286060651188e-01, 8.2246703342411321823620794e-01, | ||
| -4.0068563438653142846657956e-01, 2.705808084277845478790009e-01, | ||
| -2.0738555102867398526627303e-01, 1.6955717699740818995241986e-01, | ||
| -1.4404989676884611811997107e-01, 1.2550966952474304242233559e-01, | ||
| -1.1133426586956469049087244e-01, 1.000994575127818085337147e-01, | ||
| -9.0954017145829042232609344e-02, 8.3353840546109004024886499e-02, | ||
| -7.6932516411352191472827157e-02, 7.1432946295361336059232779e-02, | ||
| -6.6668705882420468032903454e-02 | ||
| ) | ||
| return w * @evalpoly(w, _TAYLOR1_64...) | ||
| elseif abs(x - 2) + yabs < 0.1 | ||
| # Taylor series at z=2 | ||
| # coefficients: [1-γ; [(-1)^k * (ζ(k)-1)/k for k in 2:12]] | ||
| w = Complex(x - 2, y) | ||
| return w * @evalpoly(w, | ||
| 4.2278433509846713939348812e-01, 3.2246703342411321823620794e-01, | ||
| -6.7352301053198095133246196e-02, 2.0580808427784547879000897e-02, | ||
| -7.3855510286739852662729527e-03, 2.8905103307415232857531201e-03, | ||
| -1.1927539117032609771139825e-03, 5.0966952474304242233558822e-04, | ||
| -2.2315475845357937976132853e-04, 9.9457512781808533714662972e-05, | ||
| -4.4926236738133141700224489e-05, 2.0507212775670691553131246e-05 | ||
| ) | ||
| return w * @evalpoly(w, _TAYLOR2_64...) | ||
| else | ||
| # shift using recurrence: loggamma(z) = loggamma(z+n) - log(∏(z+k)) | ||
| shiftprod = Complex(x, yabs) | ||
|
|
@@ -266,7 +246,162 @@ function _loggamma(z::Complex{Float64}) | |
| end | ||
| end | ||
|
|
||
| # Complex BigFloat loggamma | ||
| #################################### | ||
| ## Float32 loggamma implementation | ||
| #################################### | ||
|
|
||
| const HALF_LOG2PI_F32 = 9.1893853320467274178032927f-01 | ||
| const LOGPI_F32 = 1.1447298858494002f0 | ||
| const TWO_PI_F32 = 6.2831853071795864769252842f0 | ||
|
|
||
| const _STIRLING_COEFFS_32 = ( | ||
| 8.333333333333333333333368f-02, -2.777777777777777777777778f-03, | ||
| 7.936507936507936507936508f-04, -5.952380952380952380952381f-04, | ||
| 8.417508417508417508417510f-04 | ||
| ) | ||
|
|
||
| const _TAYLOR1_32 = ( | ||
| -5.7721566490153286060651188f-01, 8.2246703342411321823620794f-01, | ||
| -4.0068563438653142846657956f-01, 2.705808084277845478790009f-01, | ||
| -2.0738555102867398526627303f-01, 1.6955717699740818995241986f-01, | ||
| -1.4404989676884611811997107f-01, 1.2550966952474304242233559f-01, | ||
| -1.1133426586956469049087244f-01, 1.000994575127818085337147f-01 | ||
| ) | ||
|
|
||
| const _TAYLOR2_32 = ( | ||
| 4.2278433509846713939348812f-01, 3.2246703342411321823620794f-01, | ||
| -6.7352301053198095133246196f-02, 2.0580808427784547879000897f-02, | ||
| -7.3855510286739852662729527f-03, 2.8905103307415232857531201f-03, | ||
| -1.1927539117032609771139825f-03, 5.0966952474304242233558822f-04 | ||
| ) | ||
|
|
||
| function _loggamma_stirling(x::Float32) | ||
| t = inv(x) | ||
| w = t * t | ||
| return muladd(x - 0.5f0, log(x), -x + HALF_LOG2PI_F32 + | ||
| t * @evalpoly(w, _STIRLING_COEFFS_32...) | ||
| ) | ||
| end | ||
|
|
||
| function _loggamma_unsafe_pos(x::Float32) | ||
| if x < 7 | ||
| n = 7 - floor(Int, x) | ||
| z = x | ||
| prod = one(x) | ||
| for i in 0:n-1 | ||
| prod *= z + i | ||
| end | ||
| return _loggamma_stirling(z + n) - log(prod) | ||
| else | ||
| return _loggamma_stirling(x) | ||
| end | ||
| end | ||
|
|
||
| function _logabsgamma_unsafe_sub0(x::Float32) | ||
| s = sinpi(x) | ||
| s == 0 && return Float32(Inf), 1 | ||
| sgn = signbit(s) ? -1 : 1 | ||
| return LOGPI_F32 - log(abs(s)) - _loggamma(1 - x), sgn | ||
| end | ||
|
|
||
| function _logabsgamma(x::Float32) | ||
| if isnan(x) | ||
| return x, 1 | ||
| elseif x > 0 | ||
| return _loggamma_unsafe_pos(x), 1 | ||
| elseif x == 0 | ||
| return Float32(Inf), Int(sign(1 / x)) | ||
| else | ||
| s = sinpi(x) | ||
| s == 0 && return Float32(Inf), 1 | ||
| sgn = signbit(s) ? -1 : 1 | ||
| return LOGPI_F32 - log(abs(s)) - _loggamma(1 - x), sgn | ||
| end | ||
| end | ||
|
|
||
| function _loggamma(x::Float32) | ||
| if isnan(x) | ||
| return x | ||
| elseif isinf(x) | ||
| return x > 0 ? Float32(Inf) : Float32(NaN) | ||
| elseif x ≤ 0 | ||
| x == 0 && return Float32(Inf) | ||
| isinteger(x) && return Float32(Inf) | ||
| y, sgn = _logabsgamma_unsafe_sub0(x) | ||
| sgn < 0 && throw(DomainError(x, "`gamma(x)` must be non-negative")) | ||
| return y | ||
| elseif x < 7 | ||
| n = 7 - floor(Int, x) | ||
| z = x | ||
| prod = one(x) | ||
| for i in 0:n-1 | ||
| prod *= z + i | ||
| end | ||
| return _loggamma_stirling(z + n) - log(prod) | ||
| else | ||
| return _loggamma_stirling(x) | ||
| end | ||
| end | ||
|
|
||
| function _loggamma_asymptotic(z::Complex{Float32}) | ||
| zinv = inv(z) | ||
| t = zinv * zinv | ||
| return (z - 0.5f0) * log(z) - z + HALF_LOG2PI_F32 + | ||
| zinv * @evalpoly(t, _STIRLING_COEFFS_32...) | ||
| end | ||
|
|
||
| function _loggamma(z::Complex{Float32}) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. similarly to above, can this be unified with the other |
||
| x = real(z) | ||
| y = imag(z) | ||
| yabs = abs(y) | ||
|
|
||
| if !isfinite(x) || !isfinite(y) | ||
| if isinf(x) && isfinite(y) | ||
| return Complex{Float32}(x, x > 0 ? (iszero(y) ? y : copysign(Float32(Inf), y)) : copysign(Float32(Inf), -y)) | ||
| elseif isfinite(x) && isinf(y) | ||
| return Complex{Float32}(-Float32(Inf), y) | ||
| else | ||
| return Complex{Float32}(Float32(NaN), Float32(NaN)) | ||
| end | ||
| elseif x > 7 || yabs > 7 | ||
| return _loggamma_asymptotic(z) | ||
| elseif x < Float32(0.1) | ||
| if x == 0 && y == 0 | ||
| return Complex{Float32}(Float32(Inf), copysign(Float32(π), -y)) | ||
| end | ||
| return Complex(LOGPI_F32, copysign(TWO_PI_F32, y) * floor(0.5f0 * x + 0.25f0)) - | ||
| log(sinpi(z)) - _loggamma(Complex{Float32}(1 - x, -y)) | ||
| elseif abs(x - 1) + yabs < 0.1f0 | ||
| w = Complex{Float32}(x - 1, y) | ||
| return w * @evalpoly(w, _TAYLOR1_32...) | ||
| elseif abs(x - 2) + yabs < 0.1f0 | ||
| w = Complex{Float32}(x - 2, y) | ||
| return w * @evalpoly(w, _TAYLOR2_32...) | ||
| else | ||
| shiftprod = Complex{Float32}(x, yabs) | ||
| xshift = x + 1 | ||
| sb = false | ||
| signflips = 0 | ||
| while xshift ≤ 7 | ||
| shiftprod *= Complex{Float32}(xshift, yabs) | ||
| sbp = signbit(imag(shiftprod)) | ||
| signflips += sbp & (sbp != sb) | ||
| sb = sbp | ||
| xshift += 1 | ||
| end | ||
| shift = log(shiftprod) | ||
| if signbit(y) | ||
| shift = Complex(real(shift), signflips * -TWO_PI_F32 - imag(shift)) | ||
| else | ||
| shift = Complex(real(shift), imag(shift) + signflips * TWO_PI_F32) | ||
| end | ||
| return _loggamma_asymptotic(Complex{Float32}(xshift, y)) - shift | ||
| end | ||
| end | ||
|
|
||
| #################################### | ||
| ## Complex{BigFloat} loggamma implementation | ||
| #################################### | ||
| # Adapted from SpecialFunctions.jl (MIT license) | ||
| # Uses Stirling series with Bernoulli numbers computed via Akiyama-Tanigawa, | ||
| # reflection formula, upward recurrence, and branch correction via Float64 oracle. | ||
|
|
@@ -370,3 +505,19 @@ function _loggamma(z::Complex{BigFloat}) | |
| end | ||
| end | ||
| end | ||
|
|
||
| function _logabsgamma(x::BigFloat) | ||
| if isnan(x) | ||
| return x, 1 | ||
| elseif isinf(x) | ||
| return x > 0 ? (x, 1) : (BigFloat(NaN), 1) | ||
| elseif x > 0 | ||
| return real(_loggamma_complex_bigfloat(Complex{BigFloat}(x, zero(BigFloat)))), 1 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any easy optimizations for real gamma for BigFloat? I'd think that at the very least the polynomials should be nicer...
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, squeezing out optimal for bigfloats is hard but better than routing to complex bigfloats is easy. I'll give it a shot |
||
| elseif iszero(x) | ||
| return BigFloat(Inf), Int(sign(1 / x)) | ||
| end | ||
|
|
||
| s = sinpi(x) | ||
| s == 0 && return BigFloat(Inf), 1 | ||
| return real(_loggamma_complex_bigfloat(Complex{BigFloat}(x, zero(BigFloat)))), (signbit(s) ? -1 : 1) | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we unify this (ideally with both the Float64 and BigFloat) nothing that we're doing here should be Float32 specific.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes we can, I split it because the implementations may diverge in the future but I'm happy to merge them with type based dispatch if you prefer that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bigfloat we can't, because the polynomial approximations that are hard coded don't scale to arbitrary precision, that's why different approaches have to be used for bigfloats
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
most of what I think could be unified for the big float version is the negative integer/nonfinite checks