Skip to content

Commit ce866b0

Browse files
committed
fix: support Base.Fix1(logpdf, dist)
1 parent 7ebc6ca commit ce866b0

File tree

6 files changed

+13
-31
lines changed

6 files changed

+13
-31
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ The package exports the following:
2323
```julia
2424
using ClosedFormExpectations
2525
using Distributions
26-
mean(ClosedFormExpectation(), Exponetial(10), log)
26+
mean(ClosedFormExpectation(), log, Exponetial(10))
2727
```

src/ClosedFormExpectations.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ function (f::ComposedFunction{typeof(log), T})(x) where {T <: Expression}
3636
return log(f.inner, x)
3737
end
3838

39+
# Logpdf structure
40+
include("logpdf.jl")
41+
3942
# expressions
4043
include("expressions/ExpLogSquare.jl")
4144
include("expressions/Product.jl")
@@ -44,7 +47,6 @@ include("expressions/Product.jl")
4447
include("Exponential/Exponential.jl")
4548
include("Normal/UnivariateNormal.jl")
4649

47-
# Logpdf structure
48-
include("logpdf.jl")
50+
4951

5052
end

src/Exponential/Exponential.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ import Distributions: Exponential, LogNormal, scale, kldivergence, entropy
22
import Base.MathConstants: eulergamma
33
import LogExpFunctions: xlogx
44

5-
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), Exponential{T}}, q::Exponential) where {T}
5+
function mean(::ClosedFormExpectation, p::Logpdf{Exponential{T}}, q::Exponential) where {T}
66
λ1 = mean(q)
7-
λ2 = mean(p.inner)
7+
λ2 = mean(p.dist)
88
return -(λ1 + xlogx(λ2))/λ2
99
end
1010

11-
function mean(::ClosedFormExpectation, p::ComposedFunction{typeof(log), LogNormal{T}}, q::Exponential) where {T}
12-
μ, σ = p.inner.μ, p.inner.σ
11+
function mean(::ClosedFormExpectation, p::Logpdf{LogNormal{T}}, q::Exponential) where {T}
12+
μ, σ = p.dist.μ, p.dist.σ
1313
λ = mean(q)
1414
return 1/(2*σ^2)*(-+eulergamma)^2 - π^2/6 - log(λ)*(-2*(eulergamma+μ) + log(λ))) + eulergamma - log(λ) - 0.5*log(2π) - log(σ)
1515
end

src/logpdf.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,4 @@ function (f::Logpdf{D})(args...) where {D <: Distribution}
1010
return logpdf(f.dist, args...;)
1111
end
1212

13-
# function convert(::Type{Logpdf}, fixed_call::Fix1{typeof{logpdf}, D}) where {D <: Distribution}
14-
# return Logpdf(D)
15-
# end
13+
mean(::ClosedFormExpectation, f::Base.Fix1{typeof(logpdf), D}, q) where {D} = mean(ClosedFormExpectation(), Logpdf(f.x), q)

test/Exponential/mean_tests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
N = 10^6
1212
samples = rand(rng, Exponential(λ), N)
1313
logpdf_samples = logpdf(LogNormal(μ, σ), samples)
14-
expectation = mean(ClosedFormExpectation(), log LogNormal(μ, σ), Exponential(λ))
14+
expectation = mean(ClosedFormExpectation(), Logpdf(LogNormal(μ, σ)), Exponential(λ))
1515
@test sigma_rule(expectation, mean(logpdf_samples), std(logpdf_samples), 10^6)
1616
end
1717
end
@@ -68,7 +68,7 @@ end
6868
end
6969
end
7070

71-
@testitem "mean(::Exponential, ::ComposedFunction{typeof(log), Exponential}" begin
71+
@testitem "mean(::ClosedFormExpectation, ::Logpdf{Exponential}, ::Exponential)" begin
7272
using Distributions
7373
using ClosedFormExpectations
7474
using StableRNGs
@@ -82,6 +82,6 @@ end
8282
N = 10^6
8383
samples = rand(rng, Exponential(λ1), N)
8484
log_samples = logpdf(Exponential(λ2), samples)
85-
@test sigma_rule(mean(ClosedFormExpectation(), log Exponential(λ2), Exponential(λ1)), mean(log_samples), std(log_samples), N)
85+
@test sigma_rule(mean(ClosedFormExpectation(), Logpdf(Exponential(λ2)), Exponential(λ1)), mean(log_samples), std(log_samples), N)
8686
end
8787
end

test/interface/fix1.jl

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)