Skip to content

Commit

Permalink
fix: Univariate Normal Logpdf update
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais committed Apr 16, 2024
1 parent 6557f93 commit 237852f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/Normal/UnivariateNormal.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import Distributions: Laplace, Normal, Rayleigh, params, cdf


function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), Normal{T}}, q::Normal) where {T}
function mean(::ClosedFormExpectation, p::Logpdf{Normal{T}}, q::Normal) where {T}
μ_q, σ_q = q.μ, q.σ
μ_p, σ_p = p.inner.μ, p.inner.σ
μ_p, σ_p = p.dist.μ, p.dist.σ
return - 1/2 * log(2 * π * σ_p^2) - (σ_q^2 + (μ_p- μ_q)^2) / (2 * σ_p^2)
end

function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), Laplace{T}}, q::Normal) where {T}
function mean(::ClosedFormExpectation, p::Logpdf{Laplace{T}}, q::Normal) where {T}
μ_q, σ_q = q.μ, q.σ
(μ_p, θ_p) = params(p.inner)
(μ_p, θ_p) = params(p.dist)
normal = Normal(0,σ_q)
diff = μ_p - μ_q
return - log(2*θ_p) - θ_p^(-1) * ( 2 * (σ_q/sqrt(2*π)) * exp(-diff^2/(2*σ_q^2)) + diff * (2 * cdf(normal,diff) - 1) )
Expand Down
8 changes: 4 additions & 4 deletions test/Exponential/mean_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), LogNormal})" begin
@testitem "mean(::ClosedFormExpectation, ::Logpdf{LogNormal}, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -16,7 +16,7 @@
end
end

@testitem "mean(::Exponential, ::typeof{log})" begin
@testitem "mean(::ClosedFormExpectation, ::typeof{log}, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -31,7 +31,7 @@ end
end
end

@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), ExpLogSquare})" begin
@testitem "mean(::ClosedFormExpectations, ::ComposedFunction{typeof(log), ExpLogSquare}, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -50,7 +50,7 @@ end
end
end

@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), ExpLogSquare x identity}" begin
@testitem "mean(::ClosedFormExpectation, ::ComposedFunction{typeof(log), ExpLogSquare x identity}, ::Exponential)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand Down
8 changes: 4 additions & 4 deletions test/Normal/mean_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "mean(::Normal, ::Normal)" begin
@testitem "mean(::ClosedFormExpectation, ::Logpdf{Normal}, ::Normal)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -11,11 +11,11 @@
N = 10^5
samples = rand(rng, Normal(μ_1, σ_1), N)
log_samples = map(x -> logpdf(Normal(μ_2, σ_2),x),samples)
@test sigma_rule(mean(ClosedFormExpectation(), log Normal(μ_2, σ_2), Normal(μ_1, σ_1)), mean(log_samples), std(log_samples), N)
@test sigma_rule(mean(ClosedFormExpectation(), Logpdf(Normal(μ_2, σ_2)), Normal(μ_1, σ_1)), mean(log_samples), std(log_samples), N)
end
end

@testitem "meanlog(::Normal, ::Laplace)" begin
@testitem "mean(::ClosedFormExpectation, ::Logpdf{Laplace}, ::Normal)" begin
using Distributions
using ClosedFormExpectations
using StableRNGs
Expand All @@ -30,6 +30,6 @@ end
laplace = Laplace(μ_2, θ)
samples = rand(rng, normal, N)
log_samples = map(x -> logpdf(laplace,x),samples)
@test sigma_rule(mean(ClosedFormExpectation(), log laplace, normal), mean(log_samples), std(log_samples), N)
@test sigma_rule(mean(ClosedFormExpectation(), Logpdf(laplace), normal), mean(log_samples), std(log_samples), N)
end
end

0 comments on commit 237852f

Please sign in to comment.