Skip to content

Commit

Permalink
feat: Exponetial over Exponetial
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais committed Apr 15, 2024
1 parent cff79e9 commit 12921af
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 5 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ version = "1.0.0-DEV"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"

[compat]
julia = "1.10"
Distributions = "0.25"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "ReTestItems", "StableRNGs", "CpuId"]
16 changes: 14 additions & 2 deletions src/Exponential/Exponential.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import Distributions: Exponential, LogNormal, scale
import Distributions: Exponential, LogNormal, scale, kldivergence, entropy
import Base.MathConstants: eulergamma
import LogExpFunctions: xlogx

function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), Exponential{T}}) where {T}
λ1 = mean(q)
λ2 = mean(p.inner)
return -(λ1 + xlogx(λ2))/λ2
end

function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), LogNormal{T}}) where {T}
μ, σ = p.inner.μ, p.inner.σ
Expand Down Expand Up @@ -36,4 +43,9 @@ function mean(::ClosedWilliamsProduct, q::Exponential, p::LogNormal)
μ, σ = p.μ, p.σ
λ = mean(q)
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) - 1/λ
end
end

function mean(::ClosedWilliamsProduct, ::Exponential, p::Exponential)
λ2 = mean(p)
return -1/λ2
end
18 changes: 18 additions & 0 deletions test/Exponential/mean_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,22 @@ end
sum_mean = mean(ClosedFormExpectation(), Exponential(λ), log ExpLogSquare(μ, σ)) + mean(ClosedFormExpectation(), Exponential(λ), log)
@test mean(ClosedFormExpectation(), Exponential(λ), log product) sum_mean
end
end

@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), Exponential}" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
using Base.MathConstants: eulergamma

include("../test_utils.jl")
rng = StableRNG(123)
for _ in 1:10
λ1 = rand(rng)*10
λ2 = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ1), N)
log_samples = logpdf(Exponential(λ2), samples)
@test sigma_rule(mean(ClosedFormExpectation(), Exponential(λ1), log Exponential(λ2)), mean(log_samples), std(log_samples), N)
end
end
21 changes: 21 additions & 0 deletions test/Exponential/williams_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,25 @@ end
expectation = mean(ClosedWilliamsProduct(), Exponential(λ), LogNormal(μ, σ))
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
end
end

@testitem "mean(::ClosedWilliamsProduct, Exponential, Exponential}" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
using Base.MathConstants: eulergamma

include("../test_utils.jl")
rng = StableRNG(123)
score(q::Exponential, x) = -1/mean(q) + x/(mean(q)^2)

for _ in 1:10
λ1 = rand(rng)*10
λ2 = rand(rng)*10
N = 10^6
samples = rand(rng, Exponential(λ1), N)
fn(x) = logpdf(Exponential(λ2), x)
williams_product = map(x -> score(Exponential(λ1), x)*fn(x), samples)
@test sigma_rule(mean(ClosedWilliamsProduct(), Exponential(λ1), Exponential(λ2)), mean(williams_product), std(williams_product), N)
end
end

0 comments on commit 12921af

Please sign in to comment.