Skip to content

Commit 9d8069e

Browse files
authored
mean(expectation::ClosedWilliamsProduct, f, q::ExponentialFamilyDistribution{NormalMeanVariance}) (#18)
1 parent 8c0085b commit 9d8069e

File tree

5 files changed

+51
-0
lines changed

5 files changed

+51
-0
lines changed

src/ClosedFormExpectations.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ include("Exponential/Exponential.jl")
5555
include("Normal/expectation.jl")
5656
include("Normal/williams/normal.jl")
5757
include("Normal/williams/normal_mean_variance.jl")
58+
include("Normal/williams/ef_parametrization.jl")
59+
5860
# gamma
5961
include("Gamma/Gamma.jl")
6062

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using StaticArrays
2+
import Distributions: var, mean, std
3+
import ExponentialFamily: ExponentialFamilyDistribution, NormalMeanVariance, getnaturalparameters
4+
5+
function mean(expectation::ClosedWilliamsProduct, f, q::ExponentialFamilyDistribution{T}) where {T <: NormalMeanVariance}
6+
η = getnaturalparameters(q)
7+
jacobian = @SMatrix [-inv(2*η[2]) η[1]/(2*η[2]^2); 0.0 (-1/η[2])^(3/2)/(2*sqrt(2))]
8+
normal = Normal(mean(q), std(q))
9+
return mean(expectation, f, normal)' * jacobian
10+
end
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
include("../normal_utils.jl")
2+
3+
using ExponentialFamily
4+
5+
score(q::ExponentialFamilyDistribution{NormalMeanVariance}, x) = sufficientstatistics(q, x) .- gradlogpartition(q)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Normal}, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
2+
include("ef_utils.jl")
3+
rng = StableRNG(123)
4+
for _ in 1:10
5+
μ1, σ1 = rand(rng)*10, rand(rng)*5
6+
μ2, σ2 = rand(rng)*10, rand(rng)*5
7+
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
8+
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Normal(μ2, σ2)), ef, score)
9+
end
10+
end
11+
12+
@testitem "mean(::ClosedWilliamsProduct, p::Abs, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
13+
include("ef_utils.jl")
14+
rng = StableRNG(123)
15+
for _ in 1:10
16+
μ1, σ1 = rand(rng)*10, rand(rng)*5
17+
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
18+
central_limit_theorem_test(ClosedWilliamsProduct(), Abs(), ef, score)
19+
end
20+
end
21+
22+
@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Laplace}, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
23+
include("ef_utils.jl")
24+
rng = StableRNG(123)
25+
for _ in 1:10
26+
μ1, σ1 = rand(rng)*10, rand(rng)*5
27+
loc, θ = rand(rng)*10, rand(rng)*10
28+
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
29+
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Laplace(loc, θ)), ef, score)
30+
end
31+
end

test/interface/gauss_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ end
2727

2828
@testset "ClosedWilliamsProduct interface" begin
2929
using Distributions
30+
using ExponentialFamily
3031

3132
nmv = NormalMeanVariance(0.0, 1.0)
3233
nmp = NormalMeanPrecision(0.0, 1.0)
34+
ef = convert(ExponentialFamilyDistribution, nmv)
3335

3436
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) isa AbstractArray
3537
@test mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1)) isa AbstractArray
3638
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1))
3739
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), nmv) mean(ClosedWilliamsProduct(), Logpdf(nmp), nmv)
40+
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), ef) isa AbstractArray
3841
end
3942
end

0 commit comments

Comments
 (0)