Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add closed form expectation for Gamma #7

Merged
merged 7 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ version = "0.0.2"
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
Aqua = "0.8.4"
CpuId = "0.3.1"
Distributions = "0.25"
LogExpFunctions = "0.3"
ReTestItems = "1.23.1"
SpecialFunctions = "2.3.1"
StableRNGs = "1.0.1"
StaticArrays = "1.9.3"
Test = "<0.0.1, 1"
julia = "1.10"

Expand Down
6 changes: 4 additions & 2 deletions src/ClosedFormExpectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ include("logpdf.jl")
# expressions
include("expressions/ExpLogSquare.jl")
include("expressions/Product.jl")
include("expressions/Square.jl")
include("expressions/Power.jl")
include("expressions/xlog2x.jl")

# rules for computing expectation of log
include("Exponential/Exponential.jl")
include("Normal/UnivariateNormal.jl")


include("Gamma/Gamma.jl")

end
28 changes: 14 additions & 14 deletions src/Exponential/Exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,45 @@ import Distributions: Exponential, LogNormal, scale, kldivergence, entropy
import Base.MathConstants: eulergamma
import LogExpFunctions: xlogx

function mean(::ClosedFormExpectation, p::Logpdf{Exponential{T}}, q::Exponential) where {T}
function mean(::ClosedFormExpectation, f::Logpdf{Exponential{T}}, q::Exponential) where {T}
λ1 = mean(q)
λ2 = mean(p.dist)
λ2 = mean(f.dist)
return -(λ1 + xlogx(λ2))/λ2
end

function mean(::ClosedFormExpectation, p::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = p.dist.μ, p.dist.σ
function mean(::ClosedFormExpectation, f::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = f.dist.μ, f.dist.σ
λ = mean(q)
return 1/(2*σ^2)*(-(μ+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ))) + eulergamma - log(λ) - 0.5*log(2π) - log(σ)
end

function mean(::ClosedFormExpectation, p::typeof(log), q::Exponential)
function mean(::ClosedFormExpectation, ::typeof(log), q::Exponential)
return -eulergamma + log(mean(q))
end

function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), typeof(identity)}, q::Exponential)
function mean(::ClosedFormExpectation, ::ComposedFunction{typeof(log), typeof(identity)}, q::Exponential)
return -eulergamma + log(mean(q))
end

function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ, σ = p.inner.μ, p.inner.σ
function mean(::ClosedFormExpectation, f::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ, σ = f.inner.μ, f.inner.σ
λ = mean(q)
return 1/(2*σ^2)*(-(μ+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ)))
end

function mean(::ClosedWilliamsProduct, p::typeof(log), q::Exponential)
function mean(::ClosedWilliamsProduct, ::typeof(log), q::Exponential)
return 1/mean(q)
end

function mean(::ClosedWilliamsProduct, p::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ = p.inner.μ
σ = p.inner.σ
function mean(::ClosedWilliamsProduct, f::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ = f.inner.μ
σ = f.inner.σ
λ = mean(q)
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ)
end

function mean(::ClosedWilliamsProduct, p::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = p.dist.μ, p.dist.σ
function mean(::ClosedWilliamsProduct, f::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = f.dist.μ, f.dist.σ
λ = mean(q)
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) - 1/λ
end
Expand Down
56 changes: 56 additions & 0 deletions src/Gamma/Gamma.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using StaticArrays
import SpecialFunctions: trigamma, digamma, polygamma
import Distributions: Gamma, shape, rate
import LogExpFunctions: xlogx

function mean(::ClosedFormExpectation, ::typeof(log), q::Gamma)
return digamma(shape(q)) + log(scale(q))
end

function mean(::ClosedWilliamsProduct, ::typeof(log), q::Gamma)
return @SVector [
polygamma(1, shape(q)),
1/scale(q)
]
end

function mean(::ClosedFormExpectation, ::typeof(xlogx), q::Gamma)
scaler = shape(q)/rate(q)
return scaler * (digamma(shape(q)+1) - log(rate(q)))
end

function mean(::ClosedFormExpectation, ::typeof(xlog2x), q::Gamma)
scaler = shape(q)/rate(q)
return scaler * ((digamma(shape(q)+1) - log(rate(q)))^2 + trigamma(shape(q)+1))
end

function mean(::ClosedFormExpectation, ::ComposedFunction{Square, typeof(log)}, q::Gamma)
return trigamma(shape(q)) + (digamma(shape(q)) - log(rate(q)))^2
end

function mean(::ClosedWilliamsProduct, ::ComposedFunction{Square, typeof(log)}, q::Gamma)
return @SVector [
polygamma(2, shape(q)) + 2 * (digamma(shape(q)) - log(rate(q)))*trigamma(shape(q)),
2 * (digamma(shape(q)) - log(rate(q))) * rate(q)
]
end

function mean(strategy::ClosedFormExpectation, ::ComposedFunction{Power{Val{3}}, typeof(log)}, q::Gamma)
Elogx = mean(strategy, log, q)
Elog2x = mean(strategy, Square() ∘ log, q)
return polygamma(2, shape(q)) + 3*Elogx * Elog2x - 2*Elogx^3
end

function mean(strategy::ClosedFormExpectation, f::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Gamma) where {T}
μ, σ = f.inner.μ, f.inner.σ
Elogx = mean(strategy, log, q)
Elog2x = mean(strategy, Square() ∘ log, q)
return -1/(2*σ^2)*(μ^2 - 2*μ*Elogx + Elog2x)
end

function mean(strategy::ClosedFormExpectation, f::Logpdf{LogNormal{T}}, q::Gamma) where {T}
μ, σ = f.dist.μ, f.dist.σ
E_logexplogsquare = mean(strategy, log ∘ ExpLogSquare(μ, σ), q)
E_logx = mean(strategy, log, q)
return E_logexplogsquare - E_logx - log(σ) - 0.5*log(2pi)
end
18 changes: 18 additions & 0 deletions src/expressions/Power.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
export Power

"""
Power

Power is a type that represents the x^N function.
"""
struct Power{T} <: Expression
n::T
end

function (f::Power{Val{N}})(x) where {N}
return x^N
end

function Base.log(::Power{Val{N}}, x) where {N}
return N * log(x)
end
16 changes: 16 additions & 0 deletions src/expressions/Square.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
export Square

"""
Square

Square is a type that represents the x^2 function.
"""
struct Square <: Expression end

function (::Square)(x)
return x^2
end

function Base.log(::Square, x)
return 2 * log(x)
end
14 changes: 14 additions & 0 deletions src/expressions/xlog2x.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
export xlog2x

"""
Return `x * log(x)^2` for `x ≥ 0`, handling ``x = 0`` by taking the downward limit.

```jldoctest
julia> xlog2x(0)
0.0
```
"""
function xlog2x(x::Number)
result = x * (log(x))^2
return iszero(x) ? zero(result) : result
end
40 changes: 5 additions & 35 deletions test/Exponential/mean_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
using StableRNGs
include("../test_utils.jl")
rng = StableRNG(123)

for _ in 1:10
μ, σ = rand(rng)*10, rand(rng)*10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), N)
logpdf_samples = logpdf(LogNormal(μ, σ), samples)
expectation = mean(ClosedFormExpectation(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
@test sigma_rule(expectation, mean(logpdf_samples), std(logpdf_samples), 10^6)
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
end
end

Expand All @@ -24,10 +20,7 @@ end
rng = StableRNG(123)
for _ in 1:10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
log_samples = log.(samples)
@test sigma_rule(mean(ClosedFormExpectation(), log, Exponential(λ)), mean(log_samples), std(log_samples), N)
central_limit_theorem_test(ClosedFormExpectation(), log, Exponential(λ))
end
end

Expand All @@ -44,27 +37,7 @@ end
σ = rand(rng)*10
λ = rand(rng)*10
N = 10^5
samples = rand(rng, Exponential(λ), N)
log_samples = log.(ExpLogSquare(μ, σ).(samples))
@test sigma_rule(mean(ClosedFormExpectation(), log ∘ ExpLogSquare(μ, σ), Exponential(λ)), mean(log_samples), std(log_samples), N)
end
end

@testitem "mean(::ClosedFormExpectation, ::ComposedFunction{typeof(log), ExpLogSquare x identity}, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
using Base.MathConstants: eulergamma

include("../test_utils.jl")
rng = StableRNG(123)
for _ in 1:10
μ = rand(rng)*10
σ = rand(rng)*10
λ = rand(rng)*10
product = ClosedFormExpectations.Product((ExpLogSquare(μ, σ), identity))
sum_mean = mean(ClosedFormExpectation(), log ∘ ExpLogSquare(μ, σ), Exponential(λ)) + mean(ClosedFormExpectation(), log, Exponential(λ))
@test mean(ClosedFormExpectation(), log ∘ product, Exponential(λ)) ≈ sum_mean
central_limit_theorem_test(ClosedFormExpectation(), log ∘ ExpLogSquare(μ, σ), Exponential(λ), N)
end
end

Expand All @@ -79,9 +52,6 @@ end
for _ in 1:10
λ1 = rand(rng)*10
λ2 = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ1), N)
log_samples = logpdf(Exponential(λ2), samples)
@test sigma_rule(mean(ClosedFormExpectation(), Logpdf(Exponential(λ2)), Exponential(λ1)), mean(log_samples), std(log_samples), N)
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Exponential(λ2)), Exponential(λ1))
end
end
35 changes: 8 additions & 27 deletions test/Exponential/williams_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "mean(::ClosedWilliamsProduct, p::log, q::Exponential)" begin
@testitem "mean(::ClosedWilliamsProduct, ::log, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -13,15 +13,11 @@

for _ in 1:10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
williams_product = map(x -> score(Exponential(λ), x)*log(x), samples)
expectation = mean(ClosedWilliamsProduct(), log, Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
central_limit_theorem_test(ClosedWilliamsProduct(), log, Exponential(λ), score)
end
end

@testitem "mean(::ClosedWilliamsProduct, p::log ∘ ExpLogSquare, q::Exponential)" begin
@testitem "mean(::ClosedWilliamsProduct, ::log ∘ ExpLogSquare, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -38,16 +34,11 @@ end
μ = rand(rng)*10
σ = rand(rng)*10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
fn(x) = (log ∘ ExpLogSquare(μ, σ))(x)
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
expectation = mean(ClosedWilliamsProduct(), log ∘ ExpLogSquare(μ, σ), Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
central_limit_theorem_test(ClosedWilliamsProduct(), log ∘ ExpLogSquare(μ, σ), Exponential(λ), score)
end
end

@testitem "mean(::ClosedWilliamsProduct, f::Logpdf{LogNormal}, q::Exponential)" begin
@testitem "mean(::ClosedWilliamsProduct, ::Logpdf{LogNormal}, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -64,16 +55,11 @@ end
μ = rand(rng)*10
σ = rand(rng)*10
λ = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
fn(x) = logpdf(LogNormal(μ, σ), x)
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
expectation = mean(ClosedWilliamsProduct(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(LogNormal(μ, σ)), Exponential(λ), score)
end
end

@testitem "mean(::ClosedWilliamsProduct, Exponential, Exponential}" begin
@testitem "mean(::ClosedWilliamsProduct, ::Logpdf{Exponential}, ::Exponential}" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -86,11 +72,6 @@ end
for _ in 1:10
λ1 = rand(rng)*10
λ2 = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ1), N)
fn(x) = logpdf(Exponential(λ2), x)
williams_product = map(x -> score(Exponential(λ1), x)*fn(x), samples)
expectation = mean(ClosedWilliamsProduct(), Logpdf(Exponential(λ2)), Exponential(λ1))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Exponential(λ2)), Exponential(λ1), score)
end
end
Loading
Loading