Skip to content

Commit

Permalink
mean(expectation::ClosedWilliamsProduct, f, q::ExponentialFamilyDistr…
Browse files Browse the repository at this point in the history
…ibution{NormalMeanVariance}) (#18)
  • Loading branch information
Nimrais authored May 8, 2024
1 parent 8c0085b commit 9d8069e
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/ClosedFormExpectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ include("Exponential/Exponential.jl")
include("Normal/expectation.jl")
include("Normal/williams/normal.jl")
include("Normal/williams/normal_mean_variance.jl")
include("Normal/williams/ef_parametrization.jl")

# gamma
include("Gamma/Gamma.jl")

Expand Down
10 changes: 10 additions & 0 deletions src/Normal/williams/ef_parametrization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using StaticArrays
import Distributions: var, mean, std
import ExponentialFamily: ExponentialFamilyDistribution, NormalMeanVariance, getnaturalparameters

function mean(expectation::ClosedWilliamsProduct, f, q::ExponentialFamilyDistribution{T}) where {T <: NormalMeanVariance}
η = getnaturalparameters(q)
jacobian = @SMatrix [-inv(2*η[2]) η[1]/(2*η[2]^2); 0.0 (-1/η[2])^(3/2)/(2*sqrt(2))]
normal = Normal(mean(q), std(q))
return mean(expectation, f, normal)' * jacobian
end
5 changes: 5 additions & 0 deletions test/Normal/williams/ef_parametrization/ef_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
include("../normal_utils.jl")

using ExponentialFamily

score(q::ExponentialFamilyDistribution{NormalMeanVariance}, x) = sufficientstatistics(q, x) .- gradlogpartition(q)
31 changes: 31 additions & 0 deletions test/Normal/williams/ef_parametrization/williams_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Normal}, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
include("ef_utils.jl")
rng = StableRNG(123)
for _ in 1:10
μ1, σ1 = rand(rng)*10, rand(rng)*5
μ2, σ2 = rand(rng)*10, rand(rng)*5
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Normal(μ2, σ2)), ef, score)
end
end

@testitem "mean(::ClosedWilliamsProduct, p::Abs, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
include("ef_utils.jl")
rng = StableRNG(123)
for _ in 1:10
μ1, σ1 = rand(rng)*10, rand(rng)*5
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
central_limit_theorem_test(ClosedWilliamsProduct(), Abs(), ef, score)
end
end

@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Laplace}, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
include("ef_utils.jl")
rng = StableRNG(123)
for _ in 1:10
μ1, σ1 = rand(rng)*10, rand(rng)*5
loc, θ = rand(rng)*10, rand(rng)*10
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Laplace(loc, θ)), ef, score)
end
end
3 changes: 3 additions & 0 deletions test/interface/gauss_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ end

@testset "ClosedWilliamsProduct interface" begin
using Distributions
using ExponentialFamily

nmv = NormalMeanVariance(0.0, 1.0)
nmp = NormalMeanPrecision(0.0, 1.0)
ef = convert(ExponentialFamilyDistribution, nmv)

@test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) isa AbstractArray
@test mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1)) isa AbstractArray
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1))
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), nmv) mean(ClosedWilliamsProduct(), Logpdf(nmp), nmv)
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), ef) isa AbstractArray
end
end

2 comments on commit 9d8069e

@Nimrais
Copy link
Collaborator Author

@Nimrais Nimrais commented on 9d8069e May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 0.1.0 already exists

Please sign in to comment.