Skip to content

Commit 7ebc6ca

Browse files
committed
fix: change the argument order
1 parent 2ce9bb3 commit 7ebc6ca

File tree

9 files changed

+63
-27
lines changed

9 files changed

+63
-27
lines changed

src/ClosedFormExpectations.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export ClosedWilliamsProduct
1313
struct ClosedFormExpectation end
1414

1515
"""
16-
mean(::ClosedFormExpectation, q, f)
16+
mean(::ClosedFormExpectation, f, q)
1717
1818
Compute the E_q[f(x)] where q is a distribution and f is a function.
1919
"""
@@ -22,7 +22,7 @@ function mean(::ClosedFormExpectation, ::Nothing, ::Nothing) end
2222
struct ClosedWilliamsProduct end
2323

2424
"""
25-
mean(::ClosedWilliamsProduct, q, f)
25+
mean(::ClosedWilliamsProduct, f, q)
2626
2727
Suppose q is a distribution with density parameterized by θ and f is a function.
2828
@@ -44,4 +44,7 @@ include("expressions/Product.jl")
4444
include("Exponential/Exponential.jl")
4545
include("Normal/UnivariateNormal.jl")
4646

47+
# Logpdf structure
48+
include("logpdf.jl")
49+
4750
end

src/Exponential/Exponential.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,50 +2,50 @@ import Distributions: Exponential, LogNormal, scale, kldivergence, entropy
22
import Base.MathConstants: eulergamma
33
import LogExpFunctions: xlogx
44

5-
function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), Exponential{T}}) where {T}
5+
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), Exponential{T}}, q::Exponential) where {T}
66
λ1 = mean(q)
77
λ2 = mean(p.inner)
88
return -(λ1 + xlogx(λ2))/λ2
99
end
1010

11-
function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), LogNormal{T}}) where {T}
11+
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), LogNormal{T}}, q::Exponential) where {T}
1212
μ, σ = p.inner.μ, p.inner.σ
1313
λ = mean(q)
1414
return 1/(2*σ^2)*(-+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ))) + eulergamma - log(λ) - 0.5*log(2π) - log(σ)
1515
end
1616

17-
function mean(::ClosedFormExpectation, q::Exponential, p::typeof(log))
17+
function mean(::ClosedFormExpectation, p::typeof(log), q::Exponential)
1818
return -eulergamma + log(mean(q))
1919
end
2020

21-
function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), typeof(identity)})
21+
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), typeof(identity)}, q::Exponential)
2222
return -eulergamma + log(mean(q))
2323
end
2424

25-
function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), ExpLogSquare{T}}) where {T}
25+
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), ExpLogSquare{T}}, q::Exponential) where {T}
2626
μ, σ = p.inner.μ, p.inner.σ
2727
λ = mean(q)
2828
return 1/(2*σ^2)*(-+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ)))
2929
end
3030

31-
function mean(::ClosedWilliamsProduct, q::Exponential, p::typeof(identity))
31+
function mean(::ClosedWilliamsProduct, p::typeof(identity), q::Exponential)
3232
return 1/mean(q)
3333
end
3434

35-
function mean(::ClosedWilliamsProduct, q::Exponential, p::ExpLogSquare)
35+
function mean(::ClosedWilliamsProduct, p::ExpLogSquare, q::Exponential)
3636
μ = p.μ
3737
σ = p.σ
3838
λ = mean(q)
3939
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ)
4040
end
4141

42-
function mean(::ClosedWilliamsProduct, q::Exponential, p::LogNormal)
42+
function mean(::ClosedWilliamsProduct, p::LogNormal, q::Exponential)
4343
μ, σ = p.μ, p.σ
4444
λ = mean(q)
4545
return 1/(2*σ^2)*(-1/λ*(-2*(eulergamma+μ) + log(λ)) - log(λ)/λ) - 1/λ
4646
end
4747

48-
function mean(::ClosedWilliamsProduct, ::Exponential, p::Exponential)
48+
function mean(::ClosedWilliamsProduct, p::Exponential, ::Exponential)
4949
λ2 = mean(p)
5050
return -1/λ2
5151
end

src/Normal/UnivariateNormal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import Distributions: Laplace, Normal, Rayleigh, params, cdf
22

33

4-
function mean(::ClosedFormExpectation, q::Normal, p::ComposedFunction{typeof(log), Normal{T}}) where {T}
4+
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), Normal{T}}, q::Normal) where {T}
55
μ_q, σ_q = q.μ, q.σ
66
μ_p, σ_p = p.inner.μ, p.inner.σ
77
return - 1/2 * log(2 * π * σ_p^2) - (σ_q^2 + (μ_p- μ_q)^2) / (2 * σ_p^2)
88
end
99

10-
function mean(::ClosedFormExpectation, q::Normal, p::ComposedFunction{typeof(log), Laplace{T}}) where {T}
10+
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), Laplace{T}}, q::Normal) where {T}
1111
μ_q, σ_q = q.μ, q.σ
1212
(μ_p, θ_p) = params(p.inner)
1313
normal = Normal(0,σ_q)

src/expressions/Product.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ struct Product{T} <: Expression
44
multipliers::T
55
end
66

7-
function mean(::ClosedFormExpectation, q, p::ComposedFunction{typeof(log), Product{T}}) where {T}
8-
return sum(mean(ClosedFormExpectation(), q, log p_i) for p_i in p.inner.multipliers)
7+
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), Product{T}}, q) where {T}
8+
return sum(mean(ClosedFormExpectation(), log p_i, q) for p_i in p.inner.multipliers)
99
end

src/logpdf.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
export Logpdf
2+
3+
import Distributions: Distribution, logpdf
4+
5+
struct Logpdf{D}
6+
dist::D
7+
end
8+
9+
function (f::Logpdf{D})(args...) where {D <: Distribution}
10+
return logpdf(f.dist, args...;)
11+
end
12+
13+
# function convert(::Type{Logpdf}, fixed_call::Fix1{typeof{logpdf}, D}) where {D <: Distribution}
14+
# return Logpdf(D)
15+
# end

test/Exponential/mean_tests.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
N = 10^6
1212
samples = rand(rng, Exponential(λ), N)
1313
logpdf_samples = logpdf(LogNormal(μ, σ), samples)
14-
expectation = mean(ClosedFormExpectation(), Exponential(λ), log LogNormal(μ, σ))
14+
expectation = mean(ClosedFormExpectation(), log LogNormal(μ, σ), Exponential))
1515
@test sigma_rule(expectation, mean(logpdf_samples), std(logpdf_samples), 10^6)
1616
end
1717
end
@@ -27,7 +27,7 @@ end
2727
N = 10^6
2828
samples = rand(rng, Exponential(λ), 10^6)
2929
log_samples = log.(samples)
30-
@test sigma_rule(mean(ClosedFormExpectation(), Exponential(λ), log), mean(log_samples), std(log_samples), N)
30+
@test sigma_rule(mean(ClosedFormExpectation(), log, Exponential(λ)), mean(log_samples), std(log_samples), N)
3131
end
3232
end
3333

@@ -46,7 +46,7 @@ end
4646
N = 10^5
4747
samples = rand(rng, Exponential(λ), N)
4848
log_samples = log.(ExpLogSquare(μ, σ).(samples))
49-
@test sigma_rule(mean(ClosedFormExpectation(), Exponential(λ), log ExpLogSquare(μ, σ)), mean(log_samples), std(log_samples), N)
49+
@test sigma_rule(mean(ClosedFormExpectation(), log ExpLogSquare(μ, σ), Exponential)), mean(log_samples), std(log_samples), N)
5050
end
5151
end
5252

@@ -63,8 +63,8 @@ end
6363
σ = rand(rng)*10
6464
λ = rand(rng)*10
6565
product = ClosedFormExpectations.Product((ExpLogSquare(μ, σ), identity))
66-
sum_mean = mean(ClosedFormExpectation(), Exponential(λ), log ExpLogSquare(μ, σ)) + mean(ClosedFormExpectation(), Exponential(λ), log)
67-
@test mean(ClosedFormExpectation(), Exponential(λ), log product) sum_mean
66+
sum_mean = mean(ClosedFormExpectation(), log ExpLogSquare(μ, σ), Exponential(λ)) + mean(ClosedFormExpectation(), log, Exponential(λ))
67+
@test mean(ClosedFormExpectation(), log product, Exponential(λ)) sum_mean
6868
end
6969
end
7070

@@ -82,6 +82,6 @@ end
8282
N = 10^6
8383
samples = rand(rng, Exponential(λ1), N)
8484
log_samples = logpdf(Exponential(λ2), samples)
85-
@test sigma_rule(mean(ClosedFormExpectation(), Exponential(λ1), log Exponential(λ2)), mean(log_samples), std(log_samples), N)
85+
@test sigma_rule(mean(ClosedFormExpectation(), log Exponential(λ2), Exponential(λ1)), mean(log_samples), std(log_samples), N)
8686
end
8787
end

test/Exponential/williams_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
N = 10^6
1717
samples = rand(rng, Exponential(λ), 10^6)
1818
williams_product = map(x -> score(Exponential(λ), x)*log(x), samples)
19-
expectation = mean(ClosedWilliamsProduct(), Exponential(λ), identity)
19+
expectation = mean(ClosedWilliamsProduct(), identity, Exponential(λ))
2020
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
2121
end
2222
end
@@ -42,7 +42,7 @@ end
4242
samples = rand(rng, Exponential(λ), 10^6)
4343
fn(x) = (log ExpLogSquare(μ, σ))(x)
4444
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
45-
expectation = mean(ClosedWilliamsProduct(), Exponential(λ), ExpLogSquare(μ, σ))
45+
expectation = mean(ClosedWilliamsProduct(), ExpLogSquare(μ, σ), Exponential))
4646
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
4747
end
4848
end
@@ -68,7 +68,7 @@ end
6868
samples = rand(rng, Exponential(λ), 10^6)
6969
fn(x) = logpdf(LogNormal(μ, σ), x)
7070
williams_product = map(x -> score(Exponential(λ), x)*fn(x), samples)
71-
expectation = mean(ClosedWilliamsProduct(), Exponential(λ), LogNormal(μ, σ))
71+
expectation = mean(ClosedWilliamsProduct(), LogNormal(μ, σ), Exponential))
7272
@test sigma_rule(expectation, mean(williams_product), std(williams_product), N)
7373
end
7474
end
@@ -90,6 +90,6 @@ end
9090
samples = rand(rng, Exponential(λ1), N)
9191
fn(x) = logpdf(Exponential(λ2), x)
9292
williams_product = map(x -> score(Exponential(λ1), x)*fn(x), samples)
93-
@test sigma_rule(mean(ClosedWilliamsProduct(), Exponential(λ1), Exponential(λ2)), mean(williams_product), std(williams_product), N)
93+
@test sigma_rule(mean(ClosedWilliamsProduct(), Exponential(λ2), Exponential(λ1)), mean(williams_product), std(williams_product), N)
9494
end
9595
end

test/Normal/mean_test.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
N = 10^5
1212
samples = rand(rng, Normal(μ_1, σ_1), N)
1313
log_samples = map(x -> logpdf(Normal(μ_2, σ_2),x),samples)
14-
@test sigma_rule(mean(ClosedFormExpectation(), Normal(μ_1, σ_1), log Normal(μ_2, σ_2)), mean(log_samples), std(log_samples), N)
14+
@test sigma_rule(mean(ClosedFormExpectation(), log Normal(μ_2, σ_2), Normal(μ_1, σ_1)), mean(log_samples), std(log_samples), N)
1515
end
1616
end
1717

@@ -30,6 +30,6 @@ end
3030
laplace = Laplace(μ_2, θ)
3131
samples = rand(rng, normal, N)
3232
log_samples = map(x -> logpdf(laplace,x),samples)
33-
@test sigma_rule(mean(ClosedFormExpectation(), normal, log laplace), mean(log_samples), std(log_samples), N)
33+
@test sigma_rule(mean(ClosedFormExpectation(), log laplace, normal), mean(log_samples), std(log_samples), N)
3434
end
3535
end

test/interface/fix1.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
@testset "Support ::Base.Fix1{typeof(logpdf), D}" begin
2+
using Distributions
3+
using ClosedFormExpectations
4+
using StableRNGs
5+
using Base.MathConstants: eulergamma
6+
7+
include("../test_utils.jl")
8+
rng = StableRNG(123)
9+
for _ in 1:10
10+
μ = rand(rng)*10
11+
σ = rand(rng)*10
12+
λ = rand(rng)*10
13+
N = 10^5
14+
samples = rand(rng, Exponential(λ), N)
15+
log_samples = log.(ExpLogSquare(μ, σ).(samples))
16+
@test sigma_rule(mean(ClosedFormExpectation(), Exponential(λ), log ExpLogSquare(μ, σ)), mean(log_samples), std(log_samples), N)
17+
end
18+
end

0 commit comments

Comments
 (0)