Skip to content

Commit 8a26ce3

Browse files
committed
fix: update ClosedWilliamsProduct with Logpdf
1 parent 60cf063 commit 8a26ce3

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

src/Exponential/Exponential.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,24 @@ function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), ExpLogSq
2828
return 1/(2*σ^2)*(-+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ)))
2929
end
3030

31-
function mean(::ClosedWilliamsProduct, p::typeof(identity), q::Exponential)
31+
function mean(::ClosedWilliamsProduct, p::typeof(log), q::Exponential)
3232
return 1/mean(q)
3333
end
3434

35-
function mean(::ClosedWilliamsProduct, p::ExpLogSquare, q::Exponential)
36-
μ = p.μ
37-
σ = p.σ
35+
function mean(::ClosedWilliamsProduct, p::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
36+
μ = p.inner.μ
37+
σ = p.inner.σ
3838
λ = mean(q)
3939
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ)
4040
end
4141

42-
function mean(::ClosedWilliamsProduct, p::LogNormal, q::Exponential)
43-
μ, σ = p.μ, p.σ
42+
function mean(::ClosedWilliamsProduct, p::Logpdf{LogNormal{T}}, q::Exponential) where {T}
43+
μ, σ = p.dist.μ, p.dist.σ
4444
λ = mean(q)
4545
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) - 1/λ
4646
end
4747

48-
function mean(::ClosedWilliamsProduct, p::Exponential, ::Exponential)
49-
λ2 = mean(p)
48+
function mean(::ClosedWilliamsProduct, f::Logpdf{Exponential{T}}, ::Exponential) where {T}
49+
λ2 = mean(f.dist)
5050
return -1/λ2
5151
end

src/logpdf.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import Distributions: Distribution, logpdf
55
"""
66
Logpdf
77
8-
A structure to represent the logpdf function of a distribution.
8+
The structure to represent the logpdf function of a distribution.
99
"""
1010
struct Logpdf{D}
1111
dist::D

test/Exponential/williams_tests.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::typeof(identity))" begin
1+
@testitem "mean(::ClosedWilliamsProduct, p::log, q::Exponential)" begin
22
using Distributions
33
using ClosedFormExpectations
44
using StableRNGs
@@ -16,12 +16,12 @@
1616
N = 10^6
1717
samples = rand(rng, Exponential(λ), 10^6)
1818
williams_product = map(x -> score(Exponential(λ), x)*log(x), samples)
19-
expectation = mean(ClosedWilliamsProduct(), identity, Exponential(λ))
19+
expectation = mean(ClosedWilliamsProduct(), log, Exponential(λ))
2020
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
2121
end
2222
end
2323

24-
@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::ExpLogSquare)" begin
24+
@testitem "mean(::ClosedWilliamsProduct, p::ExpLogSquare, q::Exponential)" begin
2525
using Distributions
2626
using ClosedFormExpectations
2727
using StableRNGs
@@ -42,12 +42,12 @@ end
4242
samples = rand(rng, Exponential(λ), 10^6)
4343
fn(x) = (log ExpLogSquare(μ, σ))(x)
4444
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
45-
expectation = mean(ClosedWilliamsProduct(), ExpLogSquare(μ, σ), Exponential(λ))
45+
expectation = mean(ClosedWilliamsProduct(), log ExpLogSquare(μ, σ), Exponential(λ))
4646
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
4747
end
4848
end
4949

50-
@testitem "mean(::ClosedWilliamsProduct, q::Exponential, p::LogNormal)" begin
50+
@testitem "mean(::ClosedWilliamsProduct, f::Logpdf{LogNormal}, q::Exponential)" begin
5151
using Distributions
5252
using ClosedFormExpectations
5353
using StableRNGs
@@ -68,7 +68,7 @@ end
6868
samples = rand(rng, Exponential(λ), 10^6)
6969
fn(x) = logpdf(LogNormal(μ, σ), x)
7070
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
71-
expectation = mean(ClosedWilliamsProduct(), LogNormal(μ, σ), Exponential(λ))
71+
expectation = mean(ClosedWilliamsProduct(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
7272
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
7373
end
7474
end
@@ -90,6 +90,7 @@ end
9090
samples = rand(rng, Exponential(λ1), N)
9191
fn(x) = logpdf(Exponential(λ2), x)
9292
williams_product = map(x -> score(Exponential(λ1), x)*fn(x), samples)
93-
@test sigma_rule(mean(ClosedWilliamsProduct(), Exponential(λ2), Exponential(λ1)), mean(williams_product), std(williams_product), N)
93+
expectation = mean(ClosedWilliamsProduct(), Logpdf(Exponential(λ2)), Exponential(λ1))
94+
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
9495
end
9596
end

0 commit comments

Comments
 (0)