Skip to content

Commit 4fd8cc7

Browse files
committed
refactor: LogExpression -> Expression
1 parent 58286f5 commit 4fd8cc7

File tree

5 files changed

+33
-8
lines changed

5 files changed

+33
-8
lines changed

src/ClosedFormExpectations.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@ Compute the E_q[f(x) ∇_θ log q(x; θ)] where q is a distribution and f is a f
3030
"""
3131
function mean(::ClosedWilliamsProduct, ::Nothing, ::Nothing) end
3232

33-
abstract type LogExpression end
33+
abstract type Expression end
3434

35-
function (f::ComposedFunction{typeof(log), T})(x) where {T <: LogExpression}
35+
function (f::ComposedFunction{typeof(log), T})(x) where {T <: Expression}
3636
return log(f.inner, x)
3737
end
3838

3939
# expressions
4040
include("expressions/ExpLogSquare.jl")
41+
include("expressions/Product.jl")
4142

4243
# rules for computing expectation of log
4344
include("Exponential/Exponential.jl")

src/Exponential/Exponential.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ function mean(::ClosedFormExpectation, q::Exponential, p::typeof(log))
1111
return -eulergamma + log(mean(q))
1212
end
1313

14+
function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), typeof(identity)})
15+
return -eulergamma + log(mean(q))
16+
end
17+
1418
function mean(::ClosedFormExpectation, q::Exponential, p::ComposedFunction{typeof(log), ExpLogSquare{T}}) where {T}
1519
μ, σ = p.inner.μ, p.inner.σ
1620
λ = mean(q)

src/expressions/ExpLogSquare.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
ExpLogSquare is a type that represents the exp(-(log(x) - μ)^2/(2σ^2)) function.
55
"""
6-
struct ExpLogSquare{T} <: LogExpression
6+
struct ExpLogSquare{T} <: Expression
77
μ::T
88
σ::T
99
end

src/expressions/Product.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
struct Product <: LogExpression
2-
multipliers::Tuple{<:LogExpression}
1+
export Product
2+
3+
struct Product{T} <: Expression
4+
multipliers::T
35
end
46

5-
function mean(::ClosedFormExpectation, q, p::ComposedFunction{typeof(log), Product})
6-
return sum(mean(ClosedFormExpectation(), q, log p_i) for p_i in p.multipliers)
7+
function mean(::ClosedFormExpectation, q, p::ComposedFunction{typeof(log), Product{T}}) where {T}
8+
return sum(mean(ClosedFormExpectation(), q, log p_i) for p_i in p.inner.multipliers)
79
end

test/Exponential/mean_tests.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131
end
3232
end
3333

34-
@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), ExpLogSquare}" begin
34+
@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), ExpLogSquare})" begin
3535
using Distributions
3636
using ClosedFormExpectations
3737
using StableRNGs
@@ -48,4 +48,22 @@ end
4848
log_samples = log.(ExpLogSquare(μ, σ).(samples))
4949
@test sigma_rule(mean(ClosedFormExpectation(), Exponential(λ), log ExpLogSquare(μ, σ)), mean(log_samples), std(log_samples), N)
5050
end
51+
end
52+
53+
@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), ExpLogSquare x identity}" begin
54+
using Distributions
55+
using ClosedFormExpectations
56+
using StableRNGs
57+
using Base.MathConstants: eulergamma
58+
59+
include("../test_utils.jl")
60+
rng = StableRNG(123)
61+
for _ in 1:10
62+
μ = rand(rng)*10
63+
σ = rand(rng)*10
64+
λ = rand(rng)*10
65+
product = ClosedFormExpectations.Product((ExpLogSquare(μ, σ), identity))
66+
sum_mean = mean(ClosedFormExpectation(), Exponential(λ), log ExpLogSquare(μ, σ)) + mean(ClosedFormExpectation(), Exponential(λ), log)
67+
@test mean(ClosedFormExpectation(), Exponential(λ), log product) sum_mean
68+
end
5169
end

0 commit comments

Comments
 (0)