Skip to content

Commit

Permalink
Support exponetial family for closed form expectation (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais authored May 8, 2024
1 parent a6133c2 commit 8c0085b
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/ClosedFormExpectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,7 @@ include("Normal/williams/normal_mean_variance.jl")
# gamma
include("Gamma/Gamma.jl")

# exponetial family distribution interface
include("exponential_family_interface.jl")

end
6 changes: 6 additions & 0 deletions src/exponential_family_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import ExponentialFamily: ExponentialFamilyDistribution

function mean(expectation::ClosedFormExpectation, f, q::ExponentialFamilyDistribution)
dist = convert(Distribution, q)
return mean(expectation, f, dist)
end
16 changes: 16 additions & 0 deletions test/Exponential/mean_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,20 @@ end
λ2 = rand(rng)*10
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Exponential(λ2)), Exponential(λ1))
end
end

@testitem "mean(::ClosedFormExpectation, ::Logpdf{Exponential}, ::ExponentialFamilyDistribution{Exponential}" begin
using Distributions
using ClosedFormExpectations
using ExponentialFamily
using StableRNGs
using Base.MathConstants: eulergamma

include("../test_utils.jl")
rng = StableRNG(123)
for _ in 1:10
λ1 = rand(rng)*10
λ2 = rand(rng)*10
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Exponential(λ2)), convert(ExponentialFamilyDistribution, Exponential(λ1)))
end
end
20 changes: 20 additions & 0 deletions test/interface/exponetial_family_interface_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@testitem "Support ExponetialFamilyDistribution for ClosedFormExpectation" begin
include("../test_utils.jl")
using ExponentialFamily
@testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{Exponential})" begin
dist = Exponential(1.0)
ef = convert(ExponentialFamilyDistribution, Exponential(1.0))
@test mean(ClosedFormExpectation(), Logpdf(Exponential(1.0)), ef) mean(ClosedFormExpectation(), Logpdf(Exponential(1.0)), dist)
end
@testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
dist = Normal(1.0, 1.0)
ef = convert(ExponentialFamilyDistribution, Normal(1.0, 1.0))
@test mean(ClosedFormExpectation(), Abs(), ef) mean(ClosedFormExpectation(), Abs(), dist)
end
@testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{Gamma})" begin
import LogExpFunctions: xlogx
dist = Gamma(1.0, 1.0)
ef = convert(ExponentialFamilyDistribution, Gamma(1.0, 1.0))
@test mean(ClosedFormExpectation(), xlogx, ef) mean(ClosedFormExpectation(), xlogx, dist)
end
end

0 comments on commit 8c0085b

Please sign in to comment.