From 12921afa5f15f19cd2c32ff68bbb36514680ef8d Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 15 Apr 2024 20:42:48 +0200 Subject: [PATCH] feat: Exponetial over Exponetial --- Project.toml | 7 ++++--- src/Exponential/Exponential.jl | 16 ++++++++++++++-- test/Exponential/mean_tests.jl | 18 ++++++++++++++++++ test/Exponential/williams_tests.jl | 21 +++++++++++++++++++++ 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 8396707..d0f2e4c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,16 +5,17 @@ version = "1.0.0-DEV" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" [compat] -julia = "1.10" Distributions = "0.25" +julia = "1.10" [extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +CpuId = "adafc99b-e345-5852-983c-f28acb93d879" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -CpuId = "adafc99b-e345-5852-983c-f28acb93d879" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "ReTestItems", "StableRNGs", "CpuId"] diff --git a/src/Exponential/Exponential.jl b/src/Exponential/Exponential.jl index c9eaad8..d0f07b4 100644 --- a/src/Exponential/Exponential.jl +++ b/src/Exponential/Exponential.jl @@ -1,5 +1,12 @@ -import Distributions: Exponential, LogNormal, scale +import Distributions: Exponential, LogNormal, scale, kldivergence, entropy import Base.MathConstants: eulergamma +import LogExpFunctions: xlogx + +function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), Exponential{T}}) where {T} + λ1 = mean(q) + λ2 = mean(p.inner) + return -(λ1 + xlogx(λ2))/λ2 +end function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), LogNormal{T}}) where {T} μ, σ = p.inner.μ, p.inner.σ @@ -36,4 +43,9 @@ function mean(::ClosedWilliamsProduct, q::Exponential, p::LogNormal) μ, σ = p.μ, p.σ λ = mean(q) return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) - 1/λ -end \ No newline at end of file +end + +function mean(::ClosedWilliamsProduct, ::Exponential, p::Exponential) + λ2 = mean(p) + return -1/λ2 +end \ No newline at end of file diff --git a/test/Exponential/mean_tests.jl b/test/Exponential/mean_tests.jl index c8b9f42..76c0da9 100644 --- a/test/Exponential/mean_tests.jl +++ b/test/Exponential/mean_tests.jl @@ -66,4 +66,22 @@ end sum_mean = mean(ClosedFormExpectation(), Exponential(λ), log ∘ ExpLogSquare(μ, σ)) + mean(ClosedFormExpectation(), Exponential(λ), log) @test mean(ClosedFormExpectation(), Exponential(λ), log ∘ product) ≈ sum_mean end +end + +@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), Exponential}" begin + using Distributions + using ClosedFormExpectations + using StableRNGs + using Base.MathConstants: eulergamma + + include("../test_utils.jl") + rng = StableRNG(123) + for _ in 1:10 + λ1 = rand(rng)*10 + λ2 = rand(rng)*10 + N = 10^6 + samples = rand(rng, Exponential(λ1), N) + log_samples = logpdf(Exponential(λ2), samples) + @test sigma_rule(mean(ClosedFormExpectation(), Exponential(λ1), log ∘ Exponential(λ2)), mean(log_samples), std(log_samples), N) + end end \ No newline at end of file diff --git a/test/Exponential/williams_tests.jl b/test/Exponential/williams_tests.jl index fe24e3c..1ac36e8 100644 --- a/test/Exponential/williams_tests.jl +++ b/test/Exponential/williams_tests.jl @@ -71,4 +71,25 @@ end expectation = mean(ClosedWilliamsProduct(), Exponential(λ), LogNormal(μ, σ)) @test sigma_rule(expectation, mean(williams_product), std(williams_product), N) end +end + +@testitem "mean(::ClosedWilliamsProduct, Exponential, Exponential}" begin + using Distributions + using ClosedFormExpectations + using StableRNGs + using Base.MathConstants: eulergamma + + include("../test_utils.jl") + rng = StableRNG(123) + score(q::Exponential, x) = -1/mean(q) + x/(mean(q)^2) + + for _ in 1:10 + λ1 = rand(rng)*10 + λ2 = rand(rng)*10 + N = 10^6 + samples = rand(rng, Exponential(λ1), N) + fn(x) = logpdf(Exponential(λ2), x) + williams_product = map(x -> score(Exponential(λ1), x)*fn(x), samples) + @test sigma_rule(mean(ClosedWilliamsProduct(), Exponential(λ1), Exponential(λ2)), mean(williams_product), std(williams_product), N) + end end \ No newline at end of file