Skip to content

Commit

Permalink
fix: update ClosedWilliamsProduct with Logpdf
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais committed Apr 16, 2024
1 parent 60cf063 commit 8a26ce3
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
16 changes: 8 additions & 8 deletions src/Exponential/Exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,24 @@ function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), ExpLogSq
return 1/(2*σ^2)*(-+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ)))
end

function mean(::ClosedWilliamsProduct, p::typeof(identity), q::Exponential)
function mean(::ClosedWilliamsProduct, p::typeof(log), q::Exponential)
return 1/mean(q)
end

function mean(::ClosedWilliamsProduct, p::ExpLogSquare, q::Exponential)
μ = p.μ
σ = p.σ
function mean(::ClosedWilliamsProduct, p::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
μ = p.inner.μ
σ = p.inner.σ
λ = mean(q)
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ)
end

function mean(::ClosedWilliamsProduct, p::LogNormal, q::Exponential)
μ, σ = p.μ, p.σ
function mean(::ClosedWilliamsProduct, p::Logpdf{LogNormal{T}}, q::Exponential) where {T}
μ, σ = p.dist.μ, p.dist.σ
λ = mean(q)
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) - 1/λ
end

function mean(::ClosedWilliamsProduct, p::Exponential, ::Exponential)
λ2 = mean(p)
function mean(::ClosedWilliamsProduct, f::Logpdf{Exponential{T}}, ::Exponential) where {T}
λ2 = mean(f.dist)
return -1/λ2
end
2 changes: 1 addition & 1 deletion src/logpdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Distributions: Distribution, logpdf
"""
Logpdf
A structure to represent the logpdf function of a distribution.
The structure to represent the logpdf function of a distribution.
"""
struct Logpdf{D}
dist::D
Expand Down
15 changes: 8 additions & 7 deletions test/Exponential/williams_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::typeof(identity))" begin
@testitem "mean(::ClosedWilliamsProduct, p::log, q::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -16,12 +16,12 @@
N = 10^6
samples = rand(rng, Exponential(λ), 10^6)
williams_product = map(x -> score(Exponential(λ), x)*log(x), samples)
expectation = mean(ClosedWilliamsProduct(), identity, Exponential(λ))
expectation = mean(ClosedWilliamsProduct(), log, Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
end
end

@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::ExpLogSquare)" begin
@testitem "mean(::ClosedWilliamsProduct, p::ExpLogSquare, q::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -42,12 +42,12 @@ end
samples = rand(rng, Exponential(λ), 10^6)
fn(x) = (log ExpLogSquare(μ, σ))(x)
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
expectation = mean(ClosedWilliamsProduct(), ExpLogSquare(μ, σ), Exponential(λ))
expectation = mean(ClosedWilliamsProduct(), log ExpLogSquare(μ, σ), Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
end
end

@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::LogNormal)" begin
@testitem "mean(::ClosedWilliamsProduct, f::Logpdf{LogNormal}, q::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -68,7 +68,7 @@ end
samples = rand(rng, Exponential(λ), 10^6)
fn(x) = logpdf(LogNormal(μ, σ), x)
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
expectation = mean(ClosedWilliamsProduct(), LogNormal(μ, σ), Exponential(λ))
expectation = mean(ClosedWilliamsProduct(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
end
end
Expand All @@ -90,6 +90,7 @@ end
samples = rand(rng, Exponential(λ1), N)
fn(x) = logpdf(Exponential(λ2), x)
williams_product = map(x -> score(Exponential(λ1), x)*fn(x), samples)
@test sigma_rule(mean(ClosedWilliamsProduct(), Exponential(λ2), Exponential(λ1)), mean(williams_product), std(williams_product), N)
expectation = mean(ClosedWilliamsProduct(), Logpdf(Exponential(λ2)), Exponential(λ1))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
end
end

0 comments on commit 8a26ce3

Please sign in to comment.