diff --git a/src/ClosedFormExpectations.jl b/src/ClosedFormExpectations.jl index 62d3129..237a13a 100644 --- a/src/ClosedFormExpectations.jl +++ b/src/ClosedFormExpectations.jl @@ -58,4 +58,7 @@ include("Normal/williams/normal_mean_variance.jl") # gamma include("Gamma/Gamma.jl") +# exponetial family distribution interface +include("exponential_family_interface.jl") + end diff --git a/src/exponential_family_interface.jl b/src/exponential_family_interface.jl new file mode 100644 index 0000000..898e7d8 --- /dev/null +++ b/src/exponential_family_interface.jl @@ -0,0 +1,6 @@ +import ExponentialFamily: ExponentialFamilyDistribution + +function mean(expectation::ClosedFormExpectation, f, q::ExponentialFamilyDistribution) + dist = convert(Distribution, q) + return mean(expectation, f, dist) +end \ No newline at end of file diff --git a/test/Exponential/mean_tests.jl b/test/Exponential/mean_tests.jl index ef2a5fa..2e2cdb8 100644 --- a/test/Exponential/mean_tests.jl +++ b/test/Exponential/mean_tests.jl @@ -54,4 +54,20 @@ end λ2 = rand(rng)*10 central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Exponential(λ2)), Exponential(λ1)) end +end + +@testitem "mean(::ClosedFormExpectation, ::Logpdf{Exponential}, ::ExponentialFamilyDistribution{Exponential}" begin + using Distributions + using ClosedFormExpectations + using ExponentialFamily + using StableRNGs + using Base.MathConstants: eulergamma + + include("../test_utils.jl") + rng = StableRNG(123) + for _ in 1:10 + λ1 = rand(rng)*10 + λ2 = rand(rng)*10 + central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Exponential(λ2)), convert(ExponentialFamilyDistribution, Exponential(λ1))) + end end \ No newline at end of file diff --git a/test/interface/exponetial_family_interface_tests.jl b/test/interface/exponetial_family_interface_tests.jl new file mode 100644 index 0000000..55fad93 --- /dev/null +++ b/test/interface/exponetial_family_interface_tests.jl @@ -0,0 +1,20 @@ +@testitem "Support ExponetialFamilyDistribution for ClosedFormExpectation" begin + include("../test_utils.jl") + using ExponentialFamily + @testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{Exponential})" begin + dist = Exponential(1.0) + ef = convert(ExponentialFamilyDistribution, Exponential(1.0)) + @test mean(ClosedFormExpectation(), Logpdf(Exponential(1.0)), ef) ≈ mean(ClosedFormExpectation(), Logpdf(Exponential(1.0)), dist) + end + @testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin + dist = Normal(1.0, 1.0) + ef = convert(ExponentialFamilyDistribution, Normal(1.0, 1.0)) + @test mean(ClosedFormExpectation(), Abs(), ef) ≈ mean(ClosedFormExpectation(), Abs(), dist) + end + @testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{Gamma})" begin + import LogExpFunctions: xlogx + dist = Gamma(1.0, 1.0) + ef = convert(ExponentialFamilyDistribution, Gamma(1.0, 1.0)) + @test mean(ClosedFormExpectation(), xlogx, ef) ≈ mean(ClosedFormExpectation(), xlogx, dist) + end +end \ No newline at end of file