From 8a26ce3f64dad2edab09f30f4ff208b12f08b7a1 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 16 Apr 2024 21:12:15 +0200 Subject: [PATCH] fix: update ClosedWilliamsProduct with Logpdf --- src/Exponential/Exponential.jl | 16 ++++++++-------- src/logpdf.jl | 2 +- test/Exponential/williams_tests.jl | 15 ++++++++------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/Exponential/Exponential.jl b/src/Exponential/Exponential.jl index a290926..a62d012 100644 --- a/src/Exponential/Exponential.jl +++ b/src/Exponential/Exponential.jl @@ -28,24 +28,24 @@ function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), ExpLogSq return 1/(2*σ^2)*(-(μ+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ))) end -function mean(::ClosedWilliamsProduct, p::typeof(identity), q::Exponential) +function mean(::ClosedWilliamsProduct, p::typeof(log), q::Exponential) return 1/mean(q) end -function mean(::ClosedWilliamsProduct, p::ExpLogSquare, q::Exponential) - μ = p.μ - σ = p.σ +function mean(::ClosedWilliamsProduct, p::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T} + μ = p.inner.μ + σ = p.inner.σ λ = mean(q) return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) end -function mean(::ClosedWilliamsProduct, p::LogNormal, q::Exponential) - μ, σ = p.μ, p.σ +function mean(::ClosedWilliamsProduct, p::Logpdf{LogNormal{T}}, q::Exponential) where {T} + μ, σ = p.dist.μ, p.dist.σ λ = mean(q) return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) - 1/λ end -function mean(::ClosedWilliamsProduct, p::Exponential, ::Exponential) - λ2 = mean(p) +function mean(::ClosedWilliamsProduct, f::Logpdf{Exponential{T}}, ::Exponential) where {T} + λ2 = mean(f.dist) return -1/λ2 end \ No newline at end of file diff --git a/src/logpdf.jl b/src/logpdf.jl index d5b5cbb..b095a92 100644 --- a/src/logpdf.jl +++ b/src/logpdf.jl @@ -5,7 +5,7 @@ import Distributions: Distribution, logpdf """ Logpdf -A structure to represent the logpdf function of a distribution. +The structure to represent the logpdf function of a distribution. """ struct Logpdf{D} dist::D diff --git a/test/Exponential/williams_tests.jl b/test/Exponential/williams_tests.jl index 0ce4010..bf3ce7c 100644 --- a/test/Exponential/williams_tests.jl +++ b/test/Exponential/williams_tests.jl @@ -1,4 +1,4 @@ -@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::typeof(identity))" begin +@testitem "mean(::ClosedWilliamsProduct, p::log, q::Exponential)" begin using Distributions using ClosedFormExpectations using StableRNGs @@ -16,12 +16,12 @@ N = 10^6 samples = rand(rng, Exponential(λ), 10^6) williams_product = map(x -> score(Exponential(λ), x)*log(x), samples) - expectation = mean(ClosedWilliamsProduct(), identity, Exponential(λ)) + expectation = mean(ClosedWilliamsProduct(), log, Exponential(λ)) @test sigma_rule(expectation, mean(williams_product), std(williams_product), N) end end -@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::ExpLogSquare)" begin +@testitem "mean(::ClosedWilliamsProduct, p::ExpLogSquare, q::Exponential)" begin using Distributions using ClosedFormExpectations using StableRNGs @@ -42,12 +42,12 @@ end samples = rand(rng, Exponential(λ), 10^6) fn(x) = (log ∘ ExpLogSquare(μ, σ))(x) williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples) - expectation = mean(ClosedWilliamsProduct(), ExpLogSquare(μ, σ), Exponential(λ)) + expectation = mean(ClosedWilliamsProduct(), log ∘ ExpLogSquare(μ, σ), Exponential(λ)) @test sigma_rule(expectation, mean(williams_product), std(williams_product), N) end end -@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::LogNormal)" begin +@testitem "mean(::ClosedWilliamsProduct, f::Logpdf{LogNormal}, q::Exponential)" begin using Distributions using ClosedFormExpectations using StableRNGs @@ -68,7 +68,7 @@ end samples = rand(rng, Exponential(λ), 10^6) fn(x) = logpdf(LogNormal(μ, σ), x) williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples) - expectation = mean(ClosedWilliamsProduct(), LogNormal(μ, σ), Exponential(λ)) + expectation = mean(ClosedWilliamsProduct(), Logpdf(LogNormal(μ, σ)), Exponential(λ)) @test sigma_rule(expectation, mean(williams_product), std(williams_product), N) end end @@ -90,6 +90,7 @@ end samples = rand(rng, Exponential(λ1), N) fn(x) = logpdf(Exponential(λ2), x) williams_product = map(x -> score(Exponential(λ1), x)*fn(x), samples) - @test sigma_rule(mean(ClosedWilliamsProduct(), Exponential(λ2), Exponential(λ1)), mean(williams_product), std(williams_product), N) + expectation = mean(ClosedWilliamsProduct(), Logpdf(Exponential(λ2)), Exponential(λ1)) + @test sigma_rule(expectation, mean(williams_product), std(williams_product), N) end end \ No newline at end of file