Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ClosedWilliamsProduct for Normal #10

Merged
merged 18 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "ReTestItems", "StableRNGs", "CpuId"]
test = ["Aqua", "CpuId", "Test", "ReTestItems", "StableRNGs"]
2 changes: 1 addition & 1 deletion src/ClosedFormExpectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ include("expressions/Abs.jl")
# rules for computing expectations
include("Exponential/Exponential.jl")
# normal
include("Normal/UnivariateNormal.jl")
include("Normal/expectation.jl")
include("Normal/williams/normal.jl")

include("Gamma/Gamma.jl")
Expand Down
12 changes: 5 additions & 7 deletions src/Normal/UnivariateNormal.jl → src/Normal/expectation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ using SpecialFunctions: erf
function mean(::ClosedFormExpectation, p::Logpdf{NormalType}, q::GaussianDistributionsFamily) where {NormalType <: GaussianDistributionsFamily}
μ_q, σ_q = mean(q), std(q)
μ_p, σ_p = mean(p.dist), std(p.dist)
return - 1/2 * log(2 * π * σ_p^2) - (σ_q^2 + (μ_p- μ_q)^2) / (2 * σ_p^2)
end
return - 1/2 * log(2 * π * σ_p^2) - (σ_q^2 + (μ_p - μ_q)^2) / (2 * σ_p^2)
end

function mean(::ClosedFormExpectation, p::Logpdf{Laplace{T}}, q::GaussianDistributionsFamily) where {T}
μ_q, σ_q = mean(q), std(q)
(μ_p, θ_p) = params(p.dist)
normal = Normal(0,σ_q)
diff = μ_p - μ_q
return - log(2*θ_p) - θ_p^(-1) * ( 2 * (σ_q/sqrt(2*π)) * exp(-diff^2/(2*σ_q^2)) + diff * (2 * cdf(normal,diff) - 1) )
(loc, θ_p) = params(p.dist)
normal = Normal(mean(q) - loc, std(q))
return -log(2*θ_p) - θ_p^(-1) * mean(ClosedFormExpectation(), Abs(), normal)
end

function mean(::ClosedFormExpectation, f::Abs, q::GaussianDistributionsFamily)
Expand Down
19 changes: 19 additions & 0 deletions src/Normal/williams/normal.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
import Distributions: Normal, std
import ExponentialFamily: GaussianDistributionsFamily
using StaticArrays
using SpecialFunctions: erfc, erf

function mean(::ClosedWilliamsProduct, p::Abs, q::Normal)
μ, σ = q.μ, q.σ
return @SVector [
erf(μ/(sqrt(2)*σ)),
sqrt(2/π)*exp(-μ^2/(2*σ^2))
]
end

function mean(::ClosedWilliamsProduct, p::Logpdf{NormalType}, q::Normal{T}) where {T, NormalType <: GaussianDistributionsFamily}
μ_q, σ_q = mean(q), std(q)
μ_p, σ_p = mean(p.dist), std(p.dist)
return @SVector [
(μ_p - μ_q)/σ_p^2,
-σ_q/σ_p^2
]
end

function mean(::ClosedWilliamsProduct, p::Logpdf{Laplace{T}}, q::Normal{T}) where {T}
(loc, θ) = params(p.dist)
normal = Normal(mean(q) - loc, std(q))
abs_mean = mean(ClosedWilliamsProduct(), Abs(), normal)
return -1/θ * abs_mean
end
2 changes: 0 additions & 2 deletions test/Gamma/gamma_utils.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
include("../test_utils.jl")

using ClosedFormExpectations
using StableRNGs
using SpecialFunctions
using Distributions

score(q::Gamma, x) = [log(x) - log(scale(q)) - polygamma(0, shape(q)), x/scale(q)^2 - shape(q)/scale(q)]

5 changes: 2 additions & 3 deletions test/Normal/mean_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ end

include("../test_utils.jl")
rng = StableRNG(123)
N = 10^5
for _ in 1:10
μ_1, σ = rand(rng)*10, rand(rng)*5
μ_2, θ = rand(rng)*10, rand(rng)*5
μ_1, σ = rand(rng)*10, rand(rng)*5
μ_2, θ = rand(rng)*10, rand(rng)*5
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Laplace(μ_2, θ)), Normal(μ_1, σ))
end
end
Expand Down
8 changes: 8 additions & 0 deletions test/Normal/normal_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
include("../test_utils.jl")

using StableRNGs
using ClosedFormExpectations
using Distributions

score(q::Normal, x) = [-(x - q.μ)/q.σ^2, -1/q.σ + (x - q.μ)^2/q.σ^3]

34 changes: 33 additions & 1 deletion test/Normal/williams/normal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,36 @@
μ, σ = rand(rng)*10, rand(rng)*10
central_limit_theorem_test(ClosedWilliamsProduct(), Abs(), Normal(μ, σ), score, 10^5)
end
end
end

@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Normal}, q::Normal)" begin
include("normal_utils.jl")
rng = StableRNG(123)
for _ in 1:10
μ1, σ1 = rand(rng)*10, rand(rng)*5
μ2, σ2 = rand(rng)*10, rand(rng)*5
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Normal(μ1, σ1)), Normal(μ2, σ2), score)
end
end

@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Laplace}, q::Normal)" begin
include("normal_utils.jl")

rng = StableRNG(123)

for _ in 1:10
μ, σ = rand(rng)*10, rand(rng)*5
loc, θ = rand(rng)*10, rand(rng)*10
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Laplace(loc, θ)), Normal(μ, σ), score)
end

@testset "compare Logpdf(Laplace) gradient with Abs gradient" begin
for _ in 1:10
μ, σ = rand(rng)*10, rand(rng)*5
θ = rand(rng)*10
williams_result_abs = -1/θ*mean(ClosedWilliamsProduct(), Abs(), Normal(μ, σ))
williams_result_laplace = mean(ClosedWilliamsProduct(), Logpdf(Laplace(0, θ)), Normal(μ, σ))
@test williams_result_abs ≈ williams_result_laplace
end
end
end
3 changes: 1 addition & 2 deletions test/interface/fix1_tests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
@testitem "Support ::Base.Fix1{typeof(logpdf), D}" begin
include("../test_utils.jl")
using Distributions
using ClosedFormExpectations
using StableRNGs
using Base.MathConstants: eulergamma

include("../test_utils.jl")
rng = StableRNG(123)
dist = Exponential(1.0)
fixed_logpdf = Base.Fix1(logpdf, dist)
Expand Down
38 changes: 38 additions & 0 deletions test/interface/gauss_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
@testitem "mean(::ClosedFormExpectation, SMT, ::NormalMeanVariance)" begin
include("../test_utils.jl")
using ClosedFormExpectations
using StableRNGs
using ExponentialFamily
using Distributions

nmv = NormalMeanVariance(0.0, 1.0)
normal = Normal(0.0, 1.0)
laplace = Laplace(0.0, 1.0)
@test mean(ClosedFormExpectation(), Logpdf(nmv), nmv) ≈ mean(ClosedFormExpectation(), Logpdf(normal), normal)
@test mean(ClosedFormExpectation(), Logpdf(laplace), nmv) ≈ mean(ClosedFormExpectation(), Logpdf(laplace), normal)
end

@testitem "Support GaussianDistributionsFamily" begin
using ExponentialFamily
using ClosedFormExpectations
using StableRNGs

@testset "ClosedFormExpectation interface" begin
nmv = NormalMeanVariance(0.0, 1.0)
nmp = NormalMeanPrecision(0.0, 1.0)

@test mean(ClosedFormExpectation(), Logpdf(nmv), nmv) isa Number
@test mean(ClosedFormExpectation(), Logpdf(nmp), nmv) isa Number
end

@testset "ClosedWilliamsProduct interface" begin
using Distributions

nmv = NormalMeanVariance(0.0, 1.0)
nmp = NormalMeanPrecision(0.0, 1.0)

@test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) isa AbstractArray
@test mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1)) isa AbstractArray
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) ≈ mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1))
end
end
28 changes: 0 additions & 28 deletions test/interface/support_ef_interface_tests.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ClosedFormExpectations
import Distributions: Normal, Laplace, std
import Distributions: Normal, Laplace, Gamma, LogNormal, std, shape, scale
using StableRNGs

function sigma_rule(expectation, mean, std, N)::Bool
Expand Down