diff --git a/src/ClosedFormExpectations.jl b/src/ClosedFormExpectations.jl index 237a13a..720043e 100644 --- a/src/ClosedFormExpectations.jl +++ b/src/ClosedFormExpectations.jl @@ -55,6 +55,8 @@ include("Exponential/Exponential.jl") include("Normal/expectation.jl") include("Normal/williams/normal.jl") include("Normal/williams/normal_mean_variance.jl") +include("Normal/williams/ef_parametrization.jl") + # gamma include("Gamma/Gamma.jl") diff --git a/src/Normal/williams/ef_parametrization.jl b/src/Normal/williams/ef_parametrization.jl new file mode 100644 index 0000000..890060f --- /dev/null +++ b/src/Normal/williams/ef_parametrization.jl @@ -0,0 +1,10 @@ +using StaticArrays +import Distributions: var, mean, std +import ExponentialFamily: ExponentialFamilyDistribution, NormalMeanVariance, getnaturalparameters + +function mean(expectation::ClosedWilliamsProduct, f, q::ExponentialFamilyDistribution{T}) where {T <: NormalMeanVariance} + η = getnaturalparameters(q) + jacobian = @SMatrix [-inv(2*η[2]) η[1]/(2*η[2]^2); 0.0 (-1/η[2])^(3/2)/(2*sqrt(2))] + normal = Normal(mean(q), std(q)) + return mean(expectation, f, normal)' * jacobian +end \ No newline at end of file diff --git a/test/Normal/williams/ef_parametrization/ef_utils.jl b/test/Normal/williams/ef_parametrization/ef_utils.jl new file mode 100644 index 0000000..19791e4 --- /dev/null +++ b/test/Normal/williams/ef_parametrization/ef_utils.jl @@ -0,0 +1,5 @@ +include("../normal_utils.jl") + +using ExponentialFamily + +score(q::ExponentialFamilyDistribution{NormalMeanVariance}, x) = sufficientstatistics(q, x) .- gradlogpartition(q) \ No newline at end of file diff --git a/test/Normal/williams/ef_parametrization/williams_tests.jl b/test/Normal/williams/ef_parametrization/williams_tests.jl new file mode 100644 index 0000000..3728db5 --- /dev/null +++ b/test/Normal/williams/ef_parametrization/williams_tests.jl @@ -0,0 +1,31 @@ +@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Normal}, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin + include("ef_utils.jl") + rng = StableRNG(123) + for _ in 1:10 + μ1, σ1 = rand(rng)*10, rand(rng)*5 + μ2, σ2 = rand(rng)*10, rand(rng)*5 + ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1)) + central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Normal(μ2, σ2)), ef, score) + end +end + +@testitem "mean(::ClosedWilliamsProduct, p::Abs, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin + include("ef_utils.jl") + rng = StableRNG(123) + for _ in 1:10 + μ1, σ1 = rand(rng)*10, rand(rng)*5 + ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1)) + central_limit_theorem_test(ClosedWilliamsProduct(), Abs(), ef, score) + end +end + +@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Laplace}, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin + include("ef_utils.jl") + rng = StableRNG(123) + for _ in 1:10 + μ1, σ1 = rand(rng)*10, rand(rng)*5 + loc, θ = rand(rng)*10, rand(rng)*10 + ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1)) + central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Laplace(loc, θ)), ef, score) + end +end \ No newline at end of file diff --git a/test/interface/gauss_tests.jl b/test/interface/gauss_tests.jl index fb01d80..717908c 100644 --- a/test/interface/gauss_tests.jl +++ b/test/interface/gauss_tests.jl @@ -27,13 +27,16 @@ end @testset "ClosedWilliamsProduct interface" begin using Distributions + using ExponentialFamily nmv = NormalMeanVariance(0.0, 1.0) nmp = NormalMeanPrecision(0.0, 1.0) + ef = convert(ExponentialFamilyDistribution, nmv) @test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) isa AbstractArray @test mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1)) isa AbstractArray @test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) ≈ mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1)) @test mean(ClosedWilliamsProduct(), Logpdf(nmv), nmv) ≈ mean(ClosedWilliamsProduct(), Logpdf(nmp), nmv) + @test mean(ClosedWilliamsProduct(), Logpdf(nmv), ef) isa AbstractArray end end \ No newline at end of file