Skip to content

Commit

Permalink
Merge pull request #7 from biaslab/add-gamma
Browse files Browse the repository at this point in the history
Add closed form expectation for Gamma
  • Loading branch information
Nimrais authored Apr 18, 2024
2 parents c299c07 + 12e6633 commit d655ae6
Show file tree
Hide file tree
Showing 14 changed files with 313 additions and 87 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ version = "0.0.2"
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
Aqua = "0.8.4"
CpuId = "0.3.1"
Distributions = "0.25"
LogExpFunctions = "0.3"
ReTestItems = "1.23.1"
SpecialFunctions = "2.3.1"
StableRNGs = "1.0.1"
StaticArrays = "1.9.3"
Test = "<0.0.1, 1"
julia = "1.10"

Expand Down
6 changes: 4 additions & 2 deletions src/ClosedFormExpectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ include("logpdf.jl")
# expressions
include("expressions/ExpLogSquare.jl")
include("expressions/Product.jl")
include("expressions/Square.jl")
include("expressions/Power.jl")
include("expressions/xlog2x.jl")

# rules for computing expectation of log
include("Exponential/Exponential.jl")
include("Normal/UnivariateNormal.jl")


include("Gamma/Gamma.jl")

end
28 changes: 14 additions & 14 deletions src/Exponential/Exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,45 @@ import Distributions: Exponential, LogNormal, scale, kldivergence, entropy
import Base.MathConstants: eulergamma
import LogExpFunctions: xlogx

function mean(::ClosedFormExpectation, p::Logpdf{Exponential{T}}, q::Exponential) where {T}
function mean(::ClosedFormExpectation, f::Logpdf{Exponential{T}}, q::Exponential) where {T}
λ1 = mean(q)
λ2 = mean(p.dist)
λ2 = mean(f.dist)
return -(λ1 + xlogx(λ2))/λ2
end

function mean(::ClosedFormExpectation, p::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = p.dist.μ, p.dist.σ
function mean(::ClosedFormExpectation, f::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = f.dist.μ, f.dist.σ
λ = mean(q)
return 1/(2*σ^2)*(-+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ))) + eulergamma - log(λ) - 0.5*log(2π) - log(σ)
end

function mean(::ClosedFormExpectation, p::typeof(log), q::Exponential)
function mean(::ClosedFormExpectation, ::typeof(log), q::Exponential)
return -eulergamma + log(mean(q))
end

function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), typeof(identity)}, q::Exponential)
function mean(::ClosedFormExpectation, ::ComposedFunction{typeof(log), typeof(identity)}, q::Exponential)
return -eulergamma + log(mean(q))
end

function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ, σ = p.inner.μ, p.inner.σ
function mean(::ClosedFormExpectation, f::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ, σ = f.inner.μ, f.inner.σ
λ = mean(q)
return 1/(2*σ^2)*(-+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ)))
end

function mean(::ClosedWilliamsProduct, p::typeof(log), q::Exponential)
function mean(::ClosedWilliamsProduct, ::typeof(log), q::Exponential)
return 1/mean(q)
end

function mean(::ClosedWilliamsProduct, p::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ = p.inner.μ
σ = p.inner.σ
function mean(::ClosedWilliamsProduct, f::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ = f.inner.μ
σ = f.inner.σ
λ = mean(q)
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ)
end

function mean(::ClosedWilliamsProduct, p::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = p.dist.μ, p.dist.σ
function mean(::ClosedWilliamsProduct, f::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = f.dist.μ, f.dist.σ
λ = mean(q)
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) - 1/λ
end
Expand Down
56 changes: 56 additions & 0 deletions src/Gamma/Gamma.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using StaticArrays
import SpecialFunctions: trigamma, digamma, polygamma
import Distributions: Gamma, shape, rate
import LogExpFunctions: xlogx

function mean(::ClosedFormExpectation, ::typeof(log), q::Gamma)
return digamma(shape(q)) + log(scale(q))
end

function mean(::ClosedWilliamsProduct, ::typeof(log), q::Gamma)
return @SVector [
polygamma(1, shape(q)),
1/scale(q)
]
end

function mean(::ClosedFormExpectation, ::typeof(xlogx), q::Gamma)
scaler = shape(q)/rate(q)
return scaler * (digamma(shape(q)+1) - log(rate(q)))
end

function mean(::ClosedFormExpectation, ::typeof(xlog2x), q::Gamma)
scaler = shape(q)/rate(q)
return scaler * ((digamma(shape(q)+1) - log(rate(q)))^2 + trigamma(shape(q)+1))
end

function mean(::ClosedFormExpectation, ::ComposedFunction{Square, typeof(log)}, q::Gamma)
return trigamma(shape(q)) + (digamma(shape(q)) - log(rate(q)))^2
end

function mean(::ClosedWilliamsProduct, ::ComposedFunction{Square, typeof(log)}, q::Gamma)
return @SVector [
polygamma(2, shape(q)) + 2 * (digamma(shape(q)) - log(rate(q)))*trigamma(shape(q)),
2 * (digamma(shape(q)) - log(rate(q))) * rate(q)
]
end

function mean(strategy::ClosedFormExpectation, ::ComposedFunction{Power{Val{3}}, typeof(log)}, q::Gamma)
Elogx = mean(strategy, log, q)
Elog2x = mean(strategy, Square() log, q)
return polygamma(2, shape(q)) + 3*Elogx * Elog2x - 2*Elogx^3
end

function mean(strategy::ClosedFormExpectation, f::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Gamma) where {T}
μ, σ = f.inner.μ, f.inner.σ
Elogx = mean(strategy, log, q)
Elog2x = mean(strategy, Square() log, q)
return -1/(2*σ^2)*^2 - 2*μ*Elogx + Elog2x)
end

function mean(strategy::ClosedFormExpectation, f::Logpdf{LogNormal{T}}, q::Gamma) where {T}
μ, σ = f.dist.μ, f.dist.σ
E_logexplogsquare = mean(strategy, log ExpLogSquare(μ, σ), q)
E_logx = mean(strategy, log, q)
return E_logexplogsquare - E_logx - log(σ) - 0.5*log(2pi)
end
18 changes: 18 additions & 0 deletions src/expressions/Power.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
export Power

"""
Power
Power is a type that represents the x^N function.
"""
struct Power{T} <: Expression
n::T
end

function (f::Power{Val{N}})(x) where {N}
return x^N
end

function Base.log(::Power{Val{N}}, x) where {N}
return N * log(x)
end
16 changes: 16 additions & 0 deletions src/expressions/Square.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
export Square

"""
Square
Square is a type that represents the x^2 function.
"""
struct Square <: Expression end

function (::Square)(x)
return x^2
end

function Base.log(::Square, x)
return 2 * log(x)
end
14 changes: 14 additions & 0 deletions src/expressions/xlog2x.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
export xlog2x

"""
Return `x * log(x)^2` for `x ≥ 0`, handling ``x = 0`` by taking the downward limit.
```jldoctest
julia> xlog2x(0)
0.0
```
"""
function xlog2x(x::Number)
result = x * (log(x))^2
return iszero(x) ? zero(result) : result
end
40 changes: 5 additions & 35 deletions test/Exponential/mean_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
using StableRNGs
include("../test_utils.jl")
rng = StableRNG(123)

for _ in 1:10
μ, σ = rand(rng)*10, rand(rng)*10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), N)
logpdf_samples = logpdf(LogNormal(μ, σ), samples)
expectation = mean(ClosedFormExpectation(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
@test sigma_rule(expectation, mean(logpdf_samples), std(logpdf_samples), 10^6)
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
end
end

Expand All @@ -24,10 +20,7 @@ end
rng = StableRNG(123)
for _ in 1:10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
log_samples = log.(samples)
@test sigma_rule(mean(ClosedFormExpectation(), log, Exponential(λ)), mean(log_samples), std(log_samples), N)
central_limit_theorem_test(ClosedFormExpectation(), log, Exponential(λ))
end
end

Expand All @@ -44,27 +37,7 @@ end
σ = rand(rng)*10
λ = rand(rng)*10
N = 10^5
samples = rand(rng, Exponential(λ), N)
log_samples = log.(ExpLogSquare(μ, σ).(samples))
@test sigma_rule(mean(ClosedFormExpectation(), log ExpLogSquare(μ, σ), Exponential(λ)), mean(log_samples), std(log_samples), N)
end
end

@testitem "mean(::ClosedFormExpectation, ::ComposedFunction{typeof(log), ExpLogSquare x identity}, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
using Base.MathConstants: eulergamma

include("../test_utils.jl")
rng = StableRNG(123)
for _ in 1:10
μ = rand(rng)*10
σ = rand(rng)*10
λ = rand(rng)*10
product = ClosedFormExpectations.Product((ExpLogSquare(μ, σ), identity))
sum_mean = mean(ClosedFormExpectation(), log ExpLogSquare(μ, σ), Exponential(λ)) + mean(ClosedFormExpectation(), log, Exponential(λ))
@test mean(ClosedFormExpectation(), log product, Exponential(λ)) sum_mean
central_limit_theorem_test(ClosedFormExpectation(), log ExpLogSquare(μ, σ), Exponential(λ), N)
end
end

Expand All @@ -79,9 +52,6 @@ end
for _ in 1:10
λ1 = rand(rng)*10
λ2 = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ1), N)
log_samples = logpdf(Exponential(λ2), samples)
@test sigma_rule(mean(ClosedFormExpectation(), Logpdf(Exponential(λ2)), Exponential(λ1)), mean(log_samples), std(log_samples), N)
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Exponential(λ2)), Exponential(λ1))
end
end
35 changes: 8 additions & 27 deletions test/Exponential/williams_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "mean(::ClosedWilliamsProduct, p::log, q::Exponential)" begin
@testitem "mean(::ClosedWilliamsProduct, ::log, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -13,15 +13,11 @@

for _ in 1:10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
williams_product = map(x -> score(Exponential(λ), x)*log(x), samples)
expectation = mean(ClosedWilliamsProduct(), log, Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
central_limit_theorem_test(ClosedWilliamsProduct(), log, Exponential(λ), score)
end
end

@testitem "mean(::ClosedWilliamsProduct, p::log ∘ ExpLogSquare, q::Exponential)" begin
@testitem "mean(::ClosedWilliamsProduct, ::log ∘ ExpLogSquare, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -38,16 +34,11 @@ end
μ = rand(rng)*10
σ = rand(rng)*10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
fn(x) = (log ExpLogSquare(μ, σ))(x)
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
expectation = mean(ClosedWilliamsProduct(), log ExpLogSquare(μ, σ), Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
central_limit_theorem_test(ClosedWilliamsProduct(), log ExpLogSquare(μ, σ), Exponential(λ), score)
end
end

@testitem "mean(::ClosedWilliamsProduct, f::Logpdf{LogNormal}, q::Exponential)" begin
@testitem "mean(::ClosedWilliamsProduct, ::Logpdf{LogNormal}, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -64,16 +55,11 @@ end
μ = rand(rng)*10
σ = rand(rng)*10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
fn(x) = logpdf(LogNormal(μ, σ), x)
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
expectation = mean(ClosedWilliamsProduct(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(LogNormal(μ, σ)), Exponential(λ), score)
end
end

@testitem "mean(::ClosedWilliamsProduct, Exponential, Exponential}" begin
@testitem "mean(::ClosedWilliamsProduct, ::Logpdf{Exponential}, ::Exponential}" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -86,11 +72,6 @@ end
for _ in 1:10
λ1 = rand(rng)*10
λ2 = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ1), N)
fn(x) = logpdf(Exponential(λ2), x)
williams_product = map(x -> score(Exponential(λ1), x)*fn(x), samples)
expectation = mean(ClosedWilliamsProduct(), Logpdf(Exponential(λ2)), Exponential(λ1))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Exponential(λ2)), Exponential(λ1), score)
end
end
Loading

0 comments on commit d655ae6

Please sign in to comment.